package auth import ( "errors" "fmt" "github.com/golang-jwt/jwt" "github.com/zeromicro/go-zero/core/logx" ) type IDTYPE int const ( // 白板用户, 以观众身份命名, 没有接收Cookie, 没有拿到guest_id的用户 IDTYPE_Onlooker IDTYPE = 0 // 登录用户 IDTYPE_User IDTYPE = 1 // 游客 接收授权拿到guest_id的用户 IDTYPE_Guest IDTYPE = 2 ) type UserInfo struct { UserId int64 `json:"user_id"` GuestId int64 `json:"guest_id"` } // GetIdType 用户确认用户身份类型 func (info *UserInfo) GetIdType() IDTYPE { if info.UserId != 0 { return IDTYPE_User } if info.GuestId != 0 { return IDTYPE_Guest } return IDTYPE_Onlooker } // IsUser 用户是不是登录用户 func (info *UserInfo) IsUser() bool { return info.UserId != 0 } // IsGuest 用户是不是游客 func (info *UserInfo) IsGuest() bool { // 必须判断登录用户在前, 用户可能会携带以前是游客到注册的身份 if info.UserId != 0 { return false } if info.GuestId != 0 { return true } return false } // IsOnlooker 白板用户: 非登录用户, 非游客, 判断为白板用户 func (info *UserInfo) IsOnlooker() bool { return info.UserId != 0 && info.GuestId != 0 } type BackendUserInfo struct { UserId int64 `json:"user_id"` DepartmentId int64 `json:"department_id"` } // 获取登录信息 func GetUserInfoFormMapClaims(claims jwt.MapClaims) (*UserInfo, error) { userinfo := &UserInfo{} if userid, ok := claims["user_id"]; ok { uid, ok := userid.(float64) if !ok { err := errors.New(fmt.Sprint("parse uid form context err:", userid)) logx.Error("parse uid form context err:", err) return nil, err } userinfo.UserId = int64(uid) } else { err := errors.New(`userid not in claims`) logx.Error(`userid not in claims`) return nil, err } if guestid, ok := claims["guest_id"]; ok { gid, ok := guestid.(float64) if !ok { err := errors.New(fmt.Sprint("parse guestid form context err:", guestid)) logx.Error("parse guestid form context err:", err) return nil, err } userinfo.GuestId = int64(gid) } else { err := errors.New(`userid not in claims`) logx.Error(`userid not in claims`) return nil, err } return userinfo, nil } // GetBackendUserInfoFormMapClaims 获取后台登录信息 func GetBackendUserInfoFormMapClaims(claims jwt.MapClaims) (*BackendUserInfo, error) { userinfo := &BackendUserInfo{} if userid, ok := claims["user_id"]; ok { uid, ok := userid.(float64) if !ok { err := errors.New(fmt.Sprint("parse uid form context err:", userid)) logx.Error("parse uid form context err:", err) return nil, err } userinfo.UserId = int64(uid) } else { err := errors.New(`userid not in claims`) logx.Error(`userid not in claims`) return nil, err } return userinfo, nil } // GenerateJwtToken 网站jwt token生成 func GenerateJwtToken(accessSecret *string, accessExpire, nowSec int64, userid int64, guestid int64) (string, error) { claims := make(jwt.MapClaims) claims["exp"] = nowSec + accessExpire claims["iat"] = nowSec if userid == 0 && guestid == 0 { err := errors.New("userid and guestid cannot be 0 at the same time") logx.Error(err) return "", err } claims["user_id"] = userid claims["guest_id"] = guestid token := jwt.New(jwt.SigningMethodHS256) token.Claims = claims return token.SignedString([]byte(*accessSecret)) } // GenerateBackendJwtToken 后台jwt token生成 func GenerateBackendJwtToken(accessSecret *string, accessExpire, nowSec int64, userId int64, departmentId int64) (string, error) { claims := make(jwt.MapClaims) claims["exp"] = nowSec + accessExpire claims["iat"] = nowSec if userId == 0 { err := errors.New("userId cannot be 0 at the same time") logx.Error(err) return "", err } claims["user_id"] = userId claims["department_id"] = departmentId token := jwt.New(jwt.SigningMethodHS256) token.Claims = claims return token.SignedString([]byte(*accessSecret)) } func getJwtClaims(AuthKey string, AccessSecret *string) (jwt.MapClaims, error) { token, err := jwt.Parse(AuthKey, func(token *jwt.Token) (interface{}, error) { // 检查签名方法是否为 HS256 if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } // 返回用于验证签名的密钥 return []byte(*AccessSecret), nil }) if err != nil { return nil, errors.New(fmt.Sprint("Error parsing token:", err)) } // 验证成功返回 if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid { return claims, nil } return nil, errors.New(fmt.Sprint("Invalid token", err)) } func CheckValueRange[T comparable](v T, rangevalues ...T) bool { for _, rv := range rangevalues { if v == rv { return true } } return false }