fusenapi/utils/auth/confirmation_link.go
2023-07-27 10:18:49 +08:00

147 lines
2.8 KiB
Go

package auth
import (
	"bytes"
	"crypto/aes"
	"crypto/cipher"
	"crypto/rand"
	"encoding/base64"
	"encoding/gob"
	"fmt"
	"net/url"
)
// ConfirmationLink builds confirmation URLs whose query string carries a token
// produced by gob-encoding a value of type T and sealing it with AES-GCM, and
// decrypts such tokens back into a *T.
type ConfirmationLink[T any] struct {
	// Secret is the key handed to aes.NewCipher by Encrypt and Decrypt.
	Secret []byte
	// DefaultQueryKey is the query parameter name that carries the token.
	DefaultQueryKey string // the default key is "token"
	// link is the base URL the token is attached to.
	link *url.URL
}
// NewConfirmationLink returns a ConfirmationLink that encrypts payloads with
// key and attaches the resulting token to the URL parsed from urlStr, under the
// query parameter "token".
//
// It panics if url.Parse rejects urlStr. The key slice is stored as given (it
// is not copied) and is not validated here; Encrypt and Decrypt pass it to
// aes.NewCipher, which only accepts 16-, 24- or 32-byte keys.
func NewConfirmationLink[T any](key []byte, urlStr string) *ConfirmationLink[T] {
	base, err := url.Parse(urlStr)
	if err != nil {
		panic(err)
	}

	cl := new(ConfirmationLink[T])
	cl.Secret = key
	cl.DefaultQueryKey = "token"
	cl.link = base
	return cl
}
// Generate encrypts obj into a token (see Encrypt) and returns the base URL
// with that token set under DefaultQueryKey (see GenerateWithToken).
// If encryption fails, it returns an empty string and the encryption error.
func (cl *ConfirmationLink[T]) Generate(obj *T) (string, error) {
	var (
		token string
		err   error
	)
	if token, err = cl.Encrypt(obj); err != nil {
		return "", err
	}
	return cl.GenerateWithToken(token)
}
// GenerateWithToken returns the base URL with its DefaultQueryKey query
// parameter set to token, replacing any existing value(s) for that key.
// Other query parameters are kept; url.Values.Encode re-emits them sorted
// by key.
//
// The base URL is copied before being modified, so cl.link is never mutated.
// As a result, concurrent calls on one ConfirmationLink do not race, and a
// token from an earlier call can never leak into a later URL (for example
// after DefaultQueryKey is changed). The error result is always nil and is
// kept for API compatibility.
func (cl *ConfirmationLink[T]) GenerateWithToken(token string) (string, error) {
	u := *cl.link // shallow copy: only RawQuery is replaced below
	q := u.Query()
	// Set overwrites existing values and adds the key when absent, so it
	// covers both the "already present" and "missing" cases.
	q.Set(cl.DefaultQueryKey, token)
	u.RawQuery = q.Encode()
	return u.String(), nil
}
// fusenMakeKey derives a 32-byte key (AES-256 sized) from keyString:
//   - if it is longer than 32 bytes, only the first 32 bytes are used;
//   - if it is shorter, its bytes are repeated cyclically to fill 32 bytes;
//   - finally the two 16-byte halves are swapped.
//
// Lengths are measured in bytes, not runes. This is a deterministic byte
// shuffle, not a KDF: it adds no entropy, so keyString must already be a
// high-entropy secret.
//
// An empty keyString cannot be expanded to 32 bytes, so it is rejected with a
// panic instead of looping forever or silently yielding an all-zero key.
func fusenMakeKey(keyString string) []byte {
	n := len(keyString)
	if n == 0 {
		panic("auth: fusenMakeKey called with an empty key")
	}

	// Cyclic fill: equivalent to repeatedly doubling the key and then
	// truncating to 32 bytes, and also covers the n >= 32 truncation case.
	var expanded [32]byte
	for i := range expanded {
		expanded[i] = keyString[i%n]
	}

	// Swap halves: result = expanded[16:32] || expanded[0:16].
	result := make([]byte, 32)
	copy(result, expanded[16:])
	copy(result[16:], expanded[:16])
	return result
}
// Encrypt gob-encodes obj, seals it with AES-GCM under cl.Secret and returns
// base64url(nonce || ciphertext || tag).
//
// A fresh random nonce is drawn for every call. GCM requires a unique nonce per
// message under a given key: reusing one leaks the XOR of plaintexts and lets
// an attacker recover the authentication key and forge tokens. Because the
// nonce is written as the token prefix, Decrypt needs no extra state to read
// it. Encrypting the same obj twice therefore yields different tokens.
//
// Errors come from gob encoding, aes.NewCipher (cl.Secret must be 16, 24 or 32
// bytes), cipher.NewGCM, or the system random source.
func (cl *ConfirmationLink[T]) Encrypt(obj *T) (string, error) {
	var buf bytes.Buffer
	if err := gob.NewEncoder(&buf).Encode(obj); err != nil {
		return "", err
	}

	block, err := aes.NewCipher(cl.Secret)
	if err != nil {
		return "", err
	}
	aesgcm, err := cipher.NewGCM(block)
	if err != nil {
		return "", err
	}

	nonce := make([]byte, aesgcm.NonceSize())
	if _, err := rand.Read(nonce); err != nil {
		return "", err
	}

	// Seal appends to its dst argument, so passing nonce as dst yields
	// nonce || ciphertext || tag in a single slice.
	ciphertext := aesgcm.Seal(nonce, nonce, buf.Bytes(), nil)
	return base64.URLEncoding.EncodeToString(ciphertext), nil
}
// Decrypt reverses Encrypt. It base64url-decodes token, takes the first 12
// bytes as the AES-GCM nonce, authenticates and decrypts the rest with
// cl.Secret, and gob-decodes the plaintext into a freshly allocated T.
//
// It returns an error if:
//   - cl.Secret is not a valid AES key;
//   - token is not valid base64url;
//   - the decoded token is shorter than the nonce;
//   - GCM authentication fails (wrong key or tampered token);
//   - the plaintext does not gob-decode into T.
func (cl *ConfirmationLink[T]) Decrypt(token string) (*T, error) {
	const nonceLen = 12 // standard GCM nonce size, as used by Encrypt

	block, err := aes.NewCipher(cl.Secret)
	if err != nil {
		return nil, err
	}

	raw, err := base64.URLEncoding.DecodeString(token)
	if err != nil {
		return nil, err
	}
	if len(raw) < nonceLen {
		return nil, fmt.Errorf("ciphertext too short")
	}

	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return nil, err
	}
	nonce, sealed := raw[:nonceLen], raw[nonceLen:]
	plaintext, err := gcm.Open(nil, nonce, sealed, nil)
	if err != nil {
		return nil, err
	}

	// Rebuild the Go value from its gob encoding.
	out := new(T)
	if err := gob.NewDecoder(bytes.NewReader(plaintext)).Decode(out); err != nil {
		return nil, err
	}
	return out, nil
}