147 lines
2.8 KiB
Go
147 lines
2.8 KiB
Go
package auth
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/aes"
|
|
"crypto/cipher"
|
|
"encoding/base64"
|
|
"encoding/gob"
|
|
"fmt"
|
|
"net/url"
|
|
)
|
|
|
|
type ConfirmationLink[T any] struct {
|
|
Secret []byte
|
|
DefaultQueryKey string // 默认key 是 token
|
|
link *url.URL
|
|
}
|
|
|
|
func NewConfirmationLink[T any](key []byte, UrlStr string) *ConfirmationLink[T] {
|
|
u, err := url.Parse(UrlStr)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
return &ConfirmationLink[T]{
|
|
Secret: key,
|
|
DefaultQueryKey: "token",
|
|
link: u,
|
|
}
|
|
}
|
|
|
|
// Generate 序列化链接传入需求的obj
|
|
func (cl *ConfirmationLink[T]) Generate(obj *T) (string, error) {
|
|
|
|
token, err := cl.Encrypt(obj)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return cl.GenerateWithToken(token)
|
|
}
|
|
|
|
// GenerateWithToken 序列化url带token
|
|
func (cl *ConfirmationLink[T]) GenerateWithToken(token string) (string, error) {
|
|
|
|
q := cl.link.Query()
|
|
if q.Has(cl.DefaultQueryKey) {
|
|
q.Set(cl.DefaultQueryKey, token)
|
|
} else {
|
|
q.Add(cl.DefaultQueryKey, token)
|
|
}
|
|
|
|
// 生成确认链接
|
|
cl.link.RawQuery = q.Encode()
|
|
|
|
return cl.link.String(), nil
|
|
}
|
|
|
|
func fusenMakeKey(keysting string) []byte {
|
|
|
|
key := []byte(keysting)
|
|
|
|
var result [32]byte
|
|
|
|
// If key length is more than 32, truncate it
|
|
if len(key) > 32 {
|
|
key = key[:32]
|
|
}
|
|
|
|
// If key length is less than 32, replicate it until it reaches 32
|
|
for len(key) < 32 {
|
|
key = append(key, key...)
|
|
}
|
|
|
|
// Only take the first 32 bytes
|
|
key = key[:32]
|
|
|
|
// Swap the first 16 bytes with the last 16 bytes
|
|
copy(result[:], key[16:])
|
|
copy(result[16:], key[:16])
|
|
|
|
return result[:]
|
|
}
|
|
|
|
func (cl *ConfirmationLink[T]) Encrypt(obj *T) (string, error) {
|
|
|
|
var buf = bytes.NewBuffer(nil)
|
|
err := gob.NewEncoder(buf).Encode(obj)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
block, err := aes.NewCipher(cl.Secret)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
nonce := make([]byte, 12)
|
|
// if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
|
// return "", err
|
|
// }
|
|
|
|
aesgcm, err := cipher.NewGCM(block)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
ciphertext := aesgcm.Seal(nonce, nonce, buf.Bytes(), nil)
|
|
|
|
return base64.URLEncoding.EncodeToString(ciphertext), nil
|
|
}
|
|
|
|
func (cl *ConfirmationLink[T]) Decrypt(ciphertext string) (*T, error) {
|
|
block, err := aes.NewCipher(cl.Secret)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
ct, err := base64.URLEncoding.DecodeString(ciphertext)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(ct) < 12 {
|
|
return nil, fmt.Errorf("ciphertext too short")
|
|
}
|
|
|
|
aesgcm, err := cipher.NewGCM(block)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
plaintext, err := aesgcm.Open(nil, ct[:12], ct[12:], nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 解出golang的结构体
|
|
var protected T
|
|
var buf = bytes.NewBuffer(plaintext)
|
|
err = gob.NewDecoder(buf).Decode(&protected)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &protected, nil
|
|
}
|