修改 websocket的支持
This commit is contained in:
parent
d10f404206
commit
2a2b5af0ca
3
.gitignore
vendored
3
.gitignore
vendored
|
@ -53,3 +53,6 @@ server/product-template/product-template
|
|||
server/shopping-cart-confirmation/shopping-cart-confirmation
|
||||
server/upload/upload
|
||||
server/webset/webset
|
||||
|
||||
|
||||
shared-state
|
|
@ -3,7 +3,7 @@ package main
|
|||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"io/fs"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
|
@ -15,6 +15,7 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
|
@ -115,6 +116,8 @@ type Backend struct {
|
|||
HttpAddress string
|
||||
Client *http.Client
|
||||
Handler http.HandlerFunc
|
||||
|
||||
Dialer *websocket.Dialer
|
||||
}
|
||||
|
||||
func NewBackend(mux *http.ServeMux, httpAddress string, muxPaths ...string) *Backend {
|
||||
|
@ -142,14 +145,29 @@ func NewBackend(mux *http.ServeMux, httpAddress string, muxPaths ...string) *Bac
|
|||
},
|
||||
}
|
||||
|
||||
dialer := &websocket.Dialer{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
NetDial: func(network, addr string) (net.Conn, error) {
|
||||
return net.Dial(network, addr)
|
||||
},
|
||||
}
|
||||
|
||||
// 创建后端服务对象,包含地址和客户端
|
||||
backend := &Backend{
|
||||
HttpAddress: httpAddress,
|
||||
Client: client,
|
||||
Dialer: dialer,
|
||||
}
|
||||
|
||||
// 创建处理请求的函数
|
||||
handleRequest := func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if websocket.IsWebSocketUpgrade(r) {
|
||||
// Handle websocket connections
|
||||
handleWebSocketProxy(w, r, backend)
|
||||
return
|
||||
}
|
||||
|
||||
// 解析目标URL,包含了查询参数
|
||||
targetURL, err := url.Parse(httpAddress + r.URL.String())
|
||||
if err != nil {
|
||||
|
@ -226,7 +244,7 @@ type Result struct {
|
|||
|
||||
// GetZeroInfo 遍历指定目录,并解析相关信息
|
||||
func GetZeroInfo(rootDir string) (results []*Result) {
|
||||
entries, err := ioutil.ReadDir(rootDir)
|
||||
entries, err := os.ReadDir(rootDir)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
@ -247,7 +265,7 @@ func GetZeroInfo(rootDir string) (results []*Result) {
|
|||
}
|
||||
|
||||
// findFoldersAndExtractInfo 查找目录并提取信息
|
||||
func findFoldersAndExtractInfo(rootDir string, entry os.FileInfo) (*Result, error) {
|
||||
func findFoldersAndExtractInfo(rootDir string, entry fs.DirEntry) (*Result, error) {
|
||||
var result *Result
|
||||
|
||||
folderName := entry.Name()
|
||||
|
@ -277,7 +295,7 @@ func findFoldersAndExtractInfo(rootDir string, entry os.FileInfo) (*Result, erro
|
|||
configPath := filepath.Join(path, "etc", folderName+".yaml")
|
||||
routesPath := filepath.Join(path, "internal", "handler", "routes.go")
|
||||
|
||||
configContent, err := ioutil.ReadFile(configPath)
|
||||
configContent, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -289,7 +307,7 @@ func findFoldersAndExtractInfo(rootDir string, entry os.FileInfo) (*Result, erro
|
|||
}
|
||||
|
||||
// 读取路由文件
|
||||
routesContent, err := ioutil.ReadFile(routesPath)
|
||||
routesContent, err := os.ReadFile(routesPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -334,3 +352,41 @@ func extractPrefixRouteValues(content string) map[string]bool {
|
|||
|
||||
return prefixPath
|
||||
}
|
||||
|
||||
func handleWebSocketProxy(w http.ResponseWriter, r *http.Request, backend *Backend) {
|
||||
target := url.URL{Scheme: "ws", Host: backend.HttpAddress, Path: r.URL.Path}
|
||||
|
||||
proxyConn, _, err := backend.Dialer.DialContext(r.Context(), target.String(), nil)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer proxyConn.Close()
|
||||
|
||||
upgrader := websocket.Upgrader{}
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
go transfer(proxyConn, conn)
|
||||
go transfer(conn, proxyConn)
|
||||
}
|
||||
|
||||
func transfer(src, dest *websocket.Conn) {
|
||||
for {
|
||||
messageType, data, err := src.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
err = dest.WriteMessage(messageType, data)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
src.Close()
|
||||
dest.Close()
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user