178 lines
4.8 KiB
Go
178 lines
4.8 KiB
Go
package auth
|
|
|
|
import (
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"log"
|
|
"net/http"
|
|
"reflect"
|
|
"strings"
|
|
)
|
|
|
|
var DefaultJwtSecret uint64 = 21321321321
|
|
var DefaultDebugJwtSecret uint64 = 3285631123
|
|
|
|
func ParseJwtTokenWithHeader[T any](header string, r *http.Request) (string, *T, error) {
|
|
//TODO:
|
|
// var u T
|
|
// return "", &u, nil
|
|
|
|
AuthKey := r.Header.Get(header)
|
|
if AuthKey == "" {
|
|
return "", nil, nil
|
|
}
|
|
if len(AuthKey) <= 15 {
|
|
return "", nil, errors.New(fmt.Sprint("Error parsing token, len:", len(AuthKey)))
|
|
}
|
|
AuthKey = AuthKey[7:]
|
|
|
|
parts := strings.Split(AuthKey, ".")
|
|
if len(parts) != 3 {
|
|
return "", nil, fmt.Errorf("Invalid JWT token")
|
|
}
|
|
|
|
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
|
if err != nil {
|
|
return "", nil, fmt.Errorf("Error unmarshalling JWT DecodeString: %s", err.Error())
|
|
}
|
|
|
|
var p T
|
|
err = json.Unmarshal(payload, &p)
|
|
if err != nil {
|
|
return "", nil, fmt.Errorf("Error unmarshalling JWT payload: %s", err)
|
|
}
|
|
|
|
return AuthKey, &p, nil
|
|
|
|
// token, err := jwt.Parse(AuthKey, func(token *jwt.Token) (interface{}, error) {
|
|
// // 检查签名方法是否为 HS256
|
|
// if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
// return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
// }
|
|
// // 返回用于验证签名的密钥
|
|
// return []byte(svcCtx.Config.Auth.AccessSecret), nil
|
|
// })
|
|
// if err != nil {
|
|
// return nil, errors.New(fmt.Sprint("Error parsing token:", err))
|
|
// }
|
|
|
|
// // 验证成功返回
|
|
// if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
|
|
// return claims, nil
|
|
// }
|
|
|
|
// return nil, errors.New(fmt.Sprint("Invalid token", err))
|
|
}
|
|
|
|
func TParseJwtTokenHeader[T any](AuthKey string) (string, *T, error) {
|
|
//TODO:
|
|
// var u T
|
|
// return "", &u, nil
|
|
|
|
parts := strings.Split(AuthKey, ".")
|
|
if len(parts) != 3 {
|
|
return "", nil, fmt.Errorf("Invalid JWT token")
|
|
}
|
|
|
|
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
|
log.Println(string(payload))
|
|
if err != nil {
|
|
return "", nil, fmt.Errorf("Error unmarshalling JWT DecodeString: %s", err.Error())
|
|
}
|
|
|
|
var p T
|
|
err = json.Unmarshal(payload, &p)
|
|
if err != nil {
|
|
return "", nil, fmt.Errorf("Error unmarshalling JWT payload: %s", err)
|
|
}
|
|
|
|
return AuthKey, &p, nil
|
|
|
|
// token, err := jwt.Parse(AuthKey, func(token *jwt.Token) (interface{}, error) {
|
|
// // 检查签名方法是否为 HS256
|
|
// if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
// return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
// }
|
|
// // 返回用于验证签名的密钥
|
|
// return []byte(svcCtx.Config.Auth.AccessSecret), nil
|
|
// })
|
|
// if err != nil {
|
|
// return nil, errors.New(fmt.Sprint("Error parsing token:", err))
|
|
// }
|
|
|
|
// // 验证成功返回
|
|
// if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
|
|
// return claims, nil
|
|
// }
|
|
|
|
// return nil, errors.New(fmt.Sprint("Invalid token", err))
|
|
}
|
|
|
|
func ParseDebugJwtTokenWithHeader(header string, r *http.Request) (*Debug, error) {
|
|
|
|
AuthKey := r.Header.Get(header)
|
|
if AuthKey == "" {
|
|
return nil, nil
|
|
}
|
|
if len(AuthKey) <= 15 {
|
|
return nil, errors.New(fmt.Sprint("Error parsing token, len:", len(AuthKey)))
|
|
}
|
|
// AuthKey = AuthKey[7:] 如果没有Bearer
|
|
|
|
claims, err := ParseJwtTokenUint64Secret(AuthKey, DefaultDebugJwtSecret)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var debug Debug
|
|
// 使用反射获取 Debug 结构体的类型和值
|
|
debugType := reflect.TypeOf(debug)
|
|
debugValue := reflect.ValueOf(&debug).Elem()
|
|
|
|
// 遍历 Debug 结构体的字段
|
|
for i := 0; i < debugType.NumField(); i++ {
|
|
field := debugType.Field(i)
|
|
tag := field.Tag.Get("json")
|
|
|
|
// 在 MapClaims 中查找对应的值
|
|
value, ok := claims[tag]
|
|
if !ok {
|
|
return nil, fmt.Errorf("`%s` tag is not exists", tag)
|
|
}
|
|
|
|
// 使用反射设置字段的值
|
|
fieldValue := debugValue.Field(i)
|
|
|
|
switch fieldValue.Kind() {
|
|
case reflect.String:
|
|
fieldValue.SetString(value.(string))
|
|
case reflect.Int, reflect.Int64, reflect.Uint, reflect.Uint64:
|
|
fieldValue.SetInt(int64(value.(float64)))
|
|
case reflect.Bool:
|
|
fieldValue.SetBool(value.(bool))
|
|
case reflect.Ptr: // 处理指针类型
|
|
if fieldValue.IsNil() { // 检查指针是否为零值
|
|
newValue := reflect.New(fieldValue.Type().Elem()) // 创建新的指针值
|
|
fieldValue.Set(newValue) // 将新值设置为字段的值
|
|
}
|
|
elemValue := fieldValue.Elem()
|
|
switch elemValue.Kind() {
|
|
case reflect.String:
|
|
elemValue.SetString(value.(string))
|
|
case reflect.Int, reflect.Int64, reflect.Uint, reflect.Uint64:
|
|
elemValue.SetInt(int64(value.(float64)))
|
|
case reflect.Bool:
|
|
elemValue.SetBool(value.(bool))
|
|
default:
|
|
return nil, fmt.Errorf("`%s` type is not supported", elemValue.Kind())
|
|
}
|
|
default:
|
|
return nil, fmt.Errorf("`%s` type is not supported", fieldValue.Kind())
|
|
}
|
|
}
|
|
|
|
return &debug, nil
|
|
}
|