diff --git a/constants/websocket.go b/constants/websocket.go index 5db15ad8..4fd14ff5 100644 --- a/constants/websocket.go +++ b/constants/websocket.go @@ -4,11 +4,9 @@ type Websocket string // websocket消息类型(主类别) const ( - WEBSOCKET_UNAUTH Websocket = "WEBSOCKET_UNAUTH" //鉴权失败 (1级消息,单向通信) - WEBSOCKET_CONNECT_ERR Websocket = "WEBSOCKET_CONNECT_ERR" //ws连接错误 (1级消息,单向通信) - WEBSOCKET_CONNECT_SUCCESS Websocket = "WEBSOCKET_CONNECT_SUCCESS" //ws连接成功 (1级消息,单向通信) - WEBSOCKET_REQUEST_REUSE_LAST_CONNECT Websocket = "WEBSOCKET_REQUEST_REUSE_LAST_CONNECT" //请求恢复为上次连接的标识 (1级消息,单向通信) - WEBSOCKET_REQUEST_RESUME_LAST_CONNECT_ERR Websocket = "WEBSOCKET_REQUEST_RESUME_LAST_CONNECT_ERR" //请求恢复为上次连接的标识错误 (1级消息,单向通信) + WEBSOCKET_UNAUTH Websocket = "WEBSOCKET_UNAUTH" //鉴权失败 (1级消息,单向通信) + WEBSOCKET_CONNECT_ERR Websocket = "WEBSOCKET_CONNECT_ERR" //ws连接错误 (1级消息,单向通信) + WEBSOCKET_CONNECT_SUCCESS Websocket = "WEBSOCKET_CONNECT_SUCCESS" //ws连接成功 (1级消息,单向通信) ) // websocket消息类型(通用通知类别) diff --git a/server/websocket/internal/logic/datatransferlogic.go b/server/websocket/internal/logic/datatransferlogic.go index ac99c474..60d59d24 100644 --- a/server/websocket/internal/logic/datatransferlogic.go +++ b/server/websocket/internal/logic/datatransferlogic.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "encoding/json" "errors" + "fmt" "fusenapi/constants" "fusenapi/utils/auth" "fusenapi/utils/basic" @@ -118,12 +119,25 @@ func (l *DataTransferLogic) DataTransfer(w http.ResponseWriter, r *http.Request) } } //把子协议携带的token设置到标准token头信息中 - token := r.Header.Get("Sec-Websocket-Protocol") + secWebsocketProtocol := r.Header.Get("Sec-Websocket-Protocol") + oldWid := "" //有token是正常用户,无则是白板用户,也可以连接 - if token != "" { - r.Header.Set("Authorization", "Bearer "+token) + if secWebsocketProtocol != "" { + s := strings.Split(secWebsocketProtocol, ",") + if len(s) != 2 { + w.Write([]byte("invalid secWebsocketProtocol param")) + return + } + //有效token + if s[0] != "empty_token" { + r.Header.Set("Authorization", "Bearer "+s[0]) + } + //有效wid + if s[1] != "empty_wid" { + oldWid = s[1] + } //设置Sec-Websocket-Protocol - upgrader.Subprotocols = []string{token} + upgrader.Subprotocols = []string{secWebsocketProtocol} } //判断下是否火狐浏览器(获取浏览器第一条消息返回有收不到的bug需要延迟1秒) userAgent := r.Header.Get("User-Agent") @@ -153,7 +167,7 @@ func (l *DataTransferLogic) DataTransfer(w http.ResponseWriter, r *http.Request) return } //设置连接 - ws, err := l.setConnPool(conn, userInfo, isFirefoxBrowser, userAgent) + ws, err := l.setConnPool(conn, userInfo, isFirefoxBrowser, userAgent, oldWid) if err != nil { conn.Close() return @@ -171,7 +185,7 @@ func (l *DataTransferLogic) DataTransfer(w http.ResponseWriter, r *http.Request) } // 设置连接 -func (l *DataTransferLogic) setConnPool(conn *websocket.Conn, userInfo *auth.UserInfo, isFirefoxBrowser bool, userAgent string) (wsConnectItem, error) { +func (l *DataTransferLogic) setConnPool(conn *websocket.Conn, userInfo *auth.UserInfo, isFirefoxBrowser bool, userAgent, oldWid string) (wsConnectItem, error) { //生成连接唯一标识(失败重试10次) uniqueId, err := l.getUniqueId(userInfo, userAgent, 10) if err != nil { @@ -179,6 +193,32 @@ func (l *DataTransferLogic) setConnPool(conn *websocket.Conn, userInfo *auth.Use l.sendGetUniqueIdErrResponse(conn) return wsConnectItem{}, err } + if oldWid != "" { + //解析传入的wid是不是属于自己的用户的 + decryptionWid, err := encryption_decryption.CBCDecrypt(oldWid) + if err != nil { + logx.Error(err) + return wsConnectItem{}, errors.New("解码wid失败") + } + lendecryptionWid := len(decryptionWid) + //合成client后缀,不是同个后缀的不能复用 + userPart := getUserJoinPart(userInfo.UserId, userInfo.GuestId, userAgent) + lenUserPart := len(userPart) + if lendecryptionWid <= lenUserPart { + return wsConnectItem{}, errors.New("length of client id is too short") + } + //尾部不同不能复用 + if decryptionWid[lendecryptionWid-lenUserPart:] != userPart { + return wsConnectItem{}, errors.New("the client id is not belong to you before") + } + //存在是不能给他申请重新绑定 + if _, ok := mapConnPool.Load(oldWid); ok { + return wsConnectItem{}, errors.New("the wid has bond by other connect") + } + //检测通过可以用旧的 + logx.Info("====复用旧的ws连接成功====") + uniqueId = oldWid + } renderCtx, renderCtxCancelFunc := context.WithCancel(l.ctx) ws := wsConnectItem{ conn: conn, @@ -423,6 +463,14 @@ func (w *wsConnectItem) respondDataFormat(msgType constants.Websocket, data inte return b } +// 获取用户拼接部分(复用标识用到) +func getUserJoinPart(userId, guestId int64, userAgent string) string { + if userId > 0 { + guestId = 0 + } + return fmt.Sprintf("|_%d_%d_|_%s_|", userId, guestId, userAgent) +} + // 处理入口缓冲队列中不同类型的数据(分发处理) func (w *wsConnectItem) allocationProcessing(data []byte) { var parseInfo websocket_data.DataTransferData diff --git a/server/websocket/internal/logic/ws_allocation_processing_factory.go b/server/websocket/internal/logic/ws_allocation_processing_factory.go index 7ab780ae..9ba718cc 100644 --- a/server/websocket/internal/logic/ws_allocation_processing_factory.go +++ b/server/websocket/internal/logic/ws_allocation_processing_factory.go @@ -21,9 +21,6 @@ func (w *wsConnectItem) newAllocationProcessor(msgType constants.Websocket) allo //图片渲染 case constants.WEBSOCKET_RENDER_IMAGE: obj = &renderProcessor{} - //刷新重连请求恢复上次连接的标识 - case constants.WEBSOCKET_REQUEST_REUSE_LAST_CONNECT: - obj = &reuseConnProcessor{} default: return nil } diff --git a/server/websocket/internal/logic/ws_err_response.go b/server/websocket/internal/logic/ws_err_response.go index 69e37145..01f5a9b2 100644 --- a/server/websocket/internal/logic/ws_err_response.go +++ b/server/websocket/internal/logic/ws_err_response.go @@ -38,8 +38,3 @@ func (w *wsConnectItem) renderErrResponse(renderId, requestId, templateTag, task } w.sendToOutChan(w.respondDataFormat(constants.WEBSOCKET_RENDER_IMAGE_ERR, data)) } - -// 复用连接错误通知 -func (w *wsConnectItem) reuseLastConnErrResponse(data interface{}) { - w.sendToOutChan(w.respondDataFormat(constants.WEBSOCKET_REQUEST_RESUME_LAST_CONNECT_ERR, data)) -} diff --git a/server/websocket/internal/logic/ws_reuse_last_connect.go b/server/websocket/internal/logic/ws_reuse_last_connect.go deleted file mode 100644 index 8dbe13ca..00000000 --- a/server/websocket/internal/logic/ws_reuse_last_connect.go +++ /dev/null @@ -1,82 +0,0 @@ -package logic - -//复用websocket连接标识 -import ( - "encoding/json" - "fmt" - "fusenapi/constants" - "fusenapi/utils/encryption_decryption" - "github.com/zeromicro/go-zero/core/logx" -) - -// 复用连接处理器 -type reuseConnProcessor struct { -} - -// 处理分发到这里的数据 -func (r *reuseConnProcessor) allocationMessage(w *wsConnectItem, data []byte) { - //logx.Info("收到请求恢复上次连接标识数据:", string(data)) - var wid string - if err := json.Unmarshal(data, &wid); err != nil { - logx.Error(" invalid format of wid :", wid) - w.incomeDataFormatErrResponse("invalid format of wid") - return - } - //解密 - decryptionWid, err := encryption_decryption.CBCDecrypt(wid) - if err != nil { - w.reuseLastConnErrResponse("invalid wid") - return - } - lendecryptionWid := len(decryptionWid) - //合成client后缀,不是同个后缀的不能复用 - userPart := getUserJoinPart(w.userId, w.guestId, w.userAgent) - lenUserPart := len(userPart) - if lendecryptionWid <= lenUserPart { - w.reuseLastConnErrResponse("length of client id is to short") - return - } - //尾部不同不能复用 - if decryptionWid[lendecryptionWid-lenUserPart:] != userPart { - w.reuseLastConnErrResponse("the client id is not belong to you before") - return - } - //存在是不能给他申请重新绑定 - if v, ok := mapConnPool.Load(wid); ok { - obj, ok := v.(wsConnectItem) - if !ok { - w.reuseLastConnErrResponse("连接断言失败") - logx.Error("连接断言失败") - return - } - //是当前自己占用(无需处理) - if obj.uniqueId == w.uniqueId { - rsp := w.respondDataFormat(constants.WEBSOCKET_CONNECT_SUCCESS, wid) - w.sendToOutChan(rsp) - return - } else { - w.reuseLastConnErrResponse("the wid is used by other people") - return - } - } - //重新绑定 - //logx.Info("开始重新绑定websocket连接标识") - oldUniqueId := w.uniqueId - w.uniqueId = wid - mapConnPool.Store(wid, *w) - //删除用户id级别之前的索引 - deleteUserConnPoolElement(w.userId, w.guestId, oldUniqueId) - //添加用户id级别索引 - createUserConnPoolElement(w.userId, w.guestId, wid) - rsp := w.respondDataFormat(constants.WEBSOCKET_CONNECT_SUCCESS, wid) - w.sendToOutChan(rsp) - //logx.Info("重新绑定websocket连接标识成功") -} - -// 获取用户拼接部分(复用标识用到) -func getUserJoinPart(userId, guestId int64, userAgent string) string { - if userId > 0 { - guestId = 0 - } - return fmt.Sprintf("|_%d_%d_|_%s_|", userId, guestId, userAgent) -}