修改 websocket的支持

This commit is contained in:
eson 2023-08-10 16:44:45 +08:00
parent d10f404206
commit 2a2b5af0ca
2 changed files with 65 additions and 6 deletions

3
.gitignore vendored
View File

@ -53,3 +53,6 @@ server/product-template/product-template
server/shopping-cart-confirmation/shopping-cart-confirmation
server/upload/upload
server/webset/webset
shared-state

View File

@ -3,7 +3,7 @@ package main
import (
"fmt"
"io"
"io/ioutil"
"io/fs"
"log"
"net"
"net/http"
@ -15,6 +15,7 @@ import (
"sync"
"time"
"github.com/gorilla/websocket"
"gopkg.in/yaml.v2"
)
@ -115,6 +116,8 @@ type Backend struct {
HttpAddress string
Client *http.Client
Handler http.HandlerFunc
Dialer *websocket.Dialer
}
func NewBackend(mux *http.ServeMux, httpAddress string, muxPaths ...string) *Backend {
@ -142,14 +145,29 @@ func NewBackend(mux *http.ServeMux, httpAddress string, muxPaths ...string) *Bac
},
}
dialer := &websocket.Dialer{
Proxy: http.ProxyFromEnvironment,
NetDial: func(network, addr string) (net.Conn, error) {
return net.Dial(network, addr)
},
}
// 创建后端服务对象,包含地址和客户端
backend := &Backend{
HttpAddress: httpAddress,
Client: client,
Dialer: dialer,
}
// 创建处理请求的函数
handleRequest := func(w http.ResponseWriter, r *http.Request) {
if websocket.IsWebSocketUpgrade(r) {
// Handle websocket connections
handleWebSocketProxy(w, r, backend)
return
}
// 解析目标URL包含了查询参数
targetURL, err := url.Parse(httpAddress + r.URL.String())
if err != nil {
@ -226,7 +244,7 @@ type Result struct {
// GetZeroInfo 遍历指定目录,并解析相关信息
func GetZeroInfo(rootDir string) (results []*Result) {
entries, err := ioutil.ReadDir(rootDir)
entries, err := os.ReadDir(rootDir)
if err != nil {
log.Fatal(err)
}
@ -247,7 +265,7 @@ func GetZeroInfo(rootDir string) (results []*Result) {
}
// findFoldersAndExtractInfo 查找目录并提取信息
func findFoldersAndExtractInfo(rootDir string, entry os.FileInfo) (*Result, error) {
func findFoldersAndExtractInfo(rootDir string, entry fs.DirEntry) (*Result, error) {
var result *Result
folderName := entry.Name()
@ -277,7 +295,7 @@ func findFoldersAndExtractInfo(rootDir string, entry os.FileInfo) (*Result, erro
configPath := filepath.Join(path, "etc", folderName+".yaml")
routesPath := filepath.Join(path, "internal", "handler", "routes.go")
configContent, err := ioutil.ReadFile(configPath)
configContent, err := os.ReadFile(configPath)
if err != nil {
return err
}
@ -289,7 +307,7 @@ func findFoldersAndExtractInfo(rootDir string, entry os.FileInfo) (*Result, erro
}
// 读取路由文件
routesContent, err := ioutil.ReadFile(routesPath)
routesContent, err := os.ReadFile(routesPath)
if err != nil {
return err
}
@ -334,3 +352,41 @@ func extractPrefixRouteValues(content string) map[string]bool {
return prefixPath
}
func handleWebSocketProxy(w http.ResponseWriter, r *http.Request, backend *Backend) {
target := url.URL{Scheme: "ws", Host: backend.HttpAddress, Path: r.URL.Path}
proxyConn, _, err := backend.Dialer.DialContext(r.Context(), target.String(), nil)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
defer proxyConn.Close()
upgrader := websocket.Upgrader{}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
go transfer(proxyConn, conn)
go transfer(conn, proxyConn)
}
func transfer(src, dest *websocket.Conn) {
for {
messageType, data, err := src.ReadMessage()
if err != nil {
break
}
err = dest.WriteMessage(messageType, data)
if err != nil {
break
}
}
src.Close()
dest.Close()
}