fusenapi/server/websocket/internal/logic/datatransferlogic.go
laodaming a8457b7dad fix
2023-08-23 16:13:14 +08:00

355 lines
9.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package logic
//websocket连接
import (
"bytes"
"encoding/hex"
"encoding/json"
"fusenapi/constants"
"fusenapi/server/websocket/internal/websocket_data"
"fusenapi/utils/auth"
"net/http"
"sync"
"time"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"context"
"fusenapi/server/websocket/internal/svc"
"github.com/zeromicro/go-zero/core/logx"
)
type DataTransferLogic struct {
logx.Logger
ctx context.Context
svcCtx *svc.ServiceContext
}
func NewDataTransferLogic(ctx context.Context, svcCtx *svc.ServiceContext) *DataTransferLogic {
return &DataTransferLogic{
Logger: logx.WithContext(ctx),
ctx: ctx,
svcCtx: svcCtx,
}
}
var (
//临时缓存对象池
buffPool = sync.Pool{
New: func() interface{} {
return bytes.Buffer{}
},
}
//升级websocket
upgrader = websocket.Upgrader{
//最大可读取大小 10M
ReadBufferSize: 1024 * 10,
//握手超时时间15s
HandshakeTimeout: time.Second * 15,
//允许跨域
CheckOrigin: func(r *http.Request) bool {
return true
},
//写的缓存池
WriteBufferPool: &buffPool,
//是否支持压缩
EnableCompression: true,
}
//websocket连接存储
mapConnPool = sync.Map{}
//公共互斥锁(复用连接标识用)
publicMutex sync.Mutex
)
// 每个连接的连接基本属性
type wsConnectItem struct {
conn *websocket.Conn //websocket的连接(基本属性)
logic *DataTransferLogic //logic(基本属性,用于获取上下文,配置或者操作数据库)
closeChan chan struct{} //ws连接关闭chan(基本属性)
isClose bool //是否已经关闭(基本属性)
uniqueId string //ws连接唯一标识(基本属性)
inChan chan []byte //接受消息缓冲池(基本属性)
outChan chan []byte //要发送回客户端的消息缓冲池(基本属性)
mutex sync.Mutex //互斥锁(关闭连接方法中用)
userId int64 //用户id(基本属性)
guestId int64 //游客id(基本属性)
renderProperty renderProperty //扩展云渲染属性(扩展属性)
}
// 请求建立连接升级websocket协议
func (l *DataTransferLogic) DataTransfer(w http.ResponseWriter, r *http.Request) {
//把子协议携带的token设置到标准token头信息中
token := r.Header.Get("Sec-Websocket-Protocol")
r.Header.Set("Authorization", "Bearer "+token)
//设置Sec-Websocket-Protocol
upgrader.Subprotocols = []string{token}
//升级websocket
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
logx.Error("http upgrade websocket err:", err)
return
}
defer conn.Close()
//鉴权不成功后断开
var (
userInfo *auth.UserInfo
isAuth bool
)
isAuth, userInfo = l.checkAuth(r)
if !isAuth {
//未授权响应消息
l.unAuthResponse(conn)
return
}
//设置连接
ws := l.setConnPool(conn, *userInfo)
defer ws.close()
//循环读客户端信息
go ws.readLoop()
//循环把数据发送给客户端
go ws.writeLoop()
//推消息到云渲染
go ws.sendLoop()
//操作连接中渲染任务的增加/删除
go ws.operationRenderTask()
//消费渲染缓冲队列
go ws.renderImage()
//心跳
ws.heartbeat()
}
// 设置连接
func (l *DataTransferLogic) setConnPool(conn *websocket.Conn, userInfo auth.UserInfo) wsConnectItem {
publicMutex.Lock()
defer publicMutex.Unlock()
//生成连接唯一标识
uniqueId := l.getUniqueId(userInfo)
ws := wsConnectItem{
conn: conn,
logic: l,
uniqueId: uniqueId,
closeChan: make(chan struct{}, 1),
inChan: make(chan []byte, 1000),
outChan: make(chan []byte, 1000),
userId: userInfo.UserId,
guestId: userInfo.GuestId,
renderProperty: renderProperty{
renderImageTask: make(map[string]*renderTask),
renderImageTaskCtlChan: make(chan renderImageControlChanItem, 500),
renderChan: make(chan []byte, 500),
},
}
//保存连接
mapConnPool.Store(uniqueId, ws)
go func() {
defer func() {
if err := recover(); err != nil {
logx.Error("set conn pool panic:", err)
}
}()
//把连接成功消息发回去
time.Sleep(time.Second * 1) //兼容下火狐(直接发回去收不到第一条消息:有待研究)
ws.sendToOutChan(ws.respondDataFormat(constants.WEBSOCKET_CONNECT_SUCCESS, uniqueId))
}()
return ws
}
// 获取唯一id
func (l *DataTransferLogic) getUniqueId(userInfo auth.UserInfo) string {
//后面拼接上用户id
uniqueId := hex.EncodeToString([]byte(uuid.New().String())) + getUserJoinPart(userInfo.UserId, userInfo.GuestId)
if _, ok := mapConnPool.Load(uniqueId); ok {
uniqueId = l.getUniqueId(userInfo)
}
return uniqueId
}
// 鉴权
func (l *DataTransferLogic) checkAuth(r *http.Request) (isAuth bool, userInfo *auth.UserInfo) {
// 解析JWT token,并对空用户进行判断
claims, err := l.svcCtx.ParseJwtToken(r)
// 如果解析JWT token出错,则返回未授权的JSON响应并记录错误消息
if err != nil {
logx.Error(err)
return false, nil
}
if claims != nil {
// 从token中获取对应的用户信息
userInfo, err = auth.GetUserInfoFormMapClaims(claims)
// 如果获取用户信息出错,则返回未授权的JSON响应并记录错误消息
if err != nil {
logx.Error(err)
return false, nil
}
//不是登录用户也不是游客
if !userInfo.IsUser() && !userInfo.IsGuest() {
return false, nil
}
return true, userInfo
}
return false, nil
}
// 鉴权失败通知
func (l *DataTransferLogic) unAuthResponse(conn *websocket.Conn) {
time.Sleep(time.Second * 1) //兼容下火狐(直接发回去收不到第一条消息:有待研究)
rsp := websocket_data.DataTransferData{
T: constants.WEBSOCKET_UNAUTH,
D: nil,
}
b, _ := json.Marshal(rsp)
//先发一条正常信息
_ = conn.WriteMessage(websocket.TextMessage, b)
//发送关闭信息
_ = conn.WriteMessage(websocket.CloseMessage, nil)
//关闭连接
conn.Close()
}
// 心跳检测
func (w *wsConnectItem) heartbeat() {
tick := time.Tick(time.Second * 5)
for {
select {
case <-w.closeChan:
return
case <-tick:
//发送心跳信息
if err := w.conn.WriteMessage(websocket.PongMessage, nil); err != nil {
logx.Error("发送心跳信息异常,关闭连接:", w.uniqueId, err)
w.close()
return
}
}
}
}
// 关闭websocket连接
func (w *wsConnectItem) close() {
w.mutex.Lock()
defer w.mutex.Unlock()
logx.Info("websocket:", w.uniqueId, " is closing...")
//发送关闭信息
_ = w.conn.WriteMessage(websocket.CloseMessage, nil)
w.conn.Close()
mapConnPool.Delete(w.uniqueId)
if !w.isClose {
w.isClose = true
close(w.closeChan)
}
logx.Info("websocket:", w.uniqueId, " is closed")
}
// 读取出口缓冲池数据输出返回给浏览器端
func (w *wsConnectItem) writeLoop() {
defer func() {
if err := recover(); err != nil {
logx.Error("write loop panic:", err)
}
}()
for {
select {
case <-w.closeChan: //如果关闭了
return
case data := <-w.outChan:
if err := w.conn.WriteMessage(websocket.TextMessage, data); err != nil {
logx.Error("websocket write loop err:", err)
w.close()
return
}
}
}
}
// 接受客户端发来的消息并写入入口缓冲池
func (w *wsConnectItem) readLoop() {
defer func() {
if err := recover(); err != nil {
logx.Error("read loop panic:", err)
}
}()
for {
select {
case <-w.closeChan: //如果关闭了
return
default:
msgType, data, err := w.conn.ReadMessage()
if err != nil {
logx.Error("接受信息错误:", err)
//关闭连接
w.close()
return
}
//ping的消息不处理
if msgType != websocket.PingMessage {
//消息传入缓冲通道
w.inChan <- data
}
}
}
}
// 消费websocket入口数据池中的数据
func (w *wsConnectItem) sendLoop() {
defer func() {
if err := recover(); err != nil {
logx.Error("send loop panic:", err)
}
}()
for {
select {
case <-w.closeChan:
return
case data := <-w.inChan:
w.dealwithReciveData(data)
}
}
}
// 把要传递给客户端的数据放入出口缓冲池
func (w *wsConnectItem) sendToOutChan(data []byte) {
select {
case <-w.closeChan:
return
case w.outChan <- data:
return
case <-time.After(time.Second * 3): //阻塞超过3秒丢弃
return
}
}
// 格式化为websocket标准返回格式
func (w *wsConnectItem) respondDataFormat(msgType constants.Websocket, data interface{}) []byte {
d := websocket_data.DataTransferData{
T: msgType,
D: data,
}
b, _ := json.Marshal(d)
return b
}
// 处理入口缓冲池中不同类型的数据(分发处理)
func (w *wsConnectItem) dealwithReciveData(data []byte) {
var parseInfo websocket_data.DataTransferData
if err := json.Unmarshal(data, &parseInfo); err != nil {
logx.Error("invalid format of websocket message:", err)
w.sendToOutChan(w.respondDataFormat(constants.WEBSOCKET_ERR_DATA_FORMAT, "invalid format of websocket message:"+string(data)))
return
}
d, _ := json.Marshal(parseInfo.D)
//分消息类型给到不同逻辑处理,可扩展
switch parseInfo.T {
//图片渲染
case constants.WEBSOCKET_RENDER_IMAGE:
w.sendToRenderChan(d)
//刷新重连请求恢复上次连接的标识
case constants.WEBSOCKET_REQUEST_REUSE_LAST_CONNECT:
w.reuseLastConnect(d)
default:
logx.Error("未知消息类型:", parseInfo.T)
}
}