package auth import ( "bytes" "crypto/aes" "crypto/cipher" "encoding/base64" "encoding/gob" "fmt" "net/url" ) type ConfirmationLink[T any] struct { Secret []byte DefaultQueryKey string // 默认key 是 token link *url.URL } func NewConfirmationLink[T any](key []byte, UrlStr string) *ConfirmationLink[T] { u, err := url.Parse(UrlStr) if err != nil { panic(err) } return &ConfirmationLink[T]{ Secret: key, DefaultQueryKey: "token", link: u, } } // Generate 序列化链接传入需求的obj func (cl *ConfirmationLink[T]) Generate(obj *T) (string, error) { token, err := cl.Encrypt(obj) if err != nil { return "", err } return cl.GenerateWithToken(token) } // GenerateWithToken 序列化url带token func (cl *ConfirmationLink[T]) GenerateWithToken(token string) (string, error) { q := cl.link.Query() if q.Has(cl.DefaultQueryKey) { q.Set(cl.DefaultQueryKey, token) } else { q.Add(cl.DefaultQueryKey, token) } // 生成确认链接 cl.link.RawQuery = q.Encode() return cl.link.String(), nil } func fusenMakeKey(keysting string) []byte { key := []byte(keysting) var result [32]byte // If key length is more than 32, truncate it if len(key) > 32 { key = key[:32] } // If key length is less than 32, replicate it until it reaches 32 for len(key) < 32 { key = append(key, key...) } // Only take the first 32 bytes key = key[:32] // Swap the first 16 bytes with the last 16 bytes copy(result[:], key[16:]) copy(result[16:], key[:16]) return result[:] } func (cl *ConfirmationLink[T]) Encrypt(obj *T) (string, error) { var buf = bytes.NewBuffer(nil) err := gob.NewEncoder(buf).Encode(obj) if err != nil { return "", err } block, err := aes.NewCipher(cl.Secret) if err != nil { return "", err } nonce := make([]byte, 12) // if _, err := io.ReadFull(rand.Reader, nonce); err != nil { // return "", err // } aesgcm, err := cipher.NewGCM(block) if err != nil { return "", err } ciphertext := aesgcm.Seal(nonce, nonce, buf.Bytes(), nil) return base64.URLEncoding.EncodeToString(ciphertext), nil } func (cl *ConfirmationLink[T]) Decrypt(ciphertext string) (*T, error) { block, err := aes.NewCipher(cl.Secret) if err != nil { return nil, err } ct, err := base64.URLEncoding.DecodeString(ciphertext) if err != nil { return nil, err } if len(ct) < 12 { return nil, fmt.Errorf("ciphertext too short") } aesgcm, err := cipher.NewGCM(block) if err != nil { return nil, err } plaintext, err := aesgcm.Open(nil, ct[:12], ct[12:], nil) if err != nil { return nil, err } // 解出golang的结构体 var protected T var buf = bytes.NewBuffer(plaintext) err = gob.NewDecoder(buf).Decode(&protected) if err != nil { return nil, err } return &protected, nil }