package encryption_decryption import ( "bytes" "crypto/aes" "crypto/cipher" "crypto/rand" "encoding/base64" "encoding/gob" "fmt" "io" "sync" ) func DerivationKeyV1(keysting string) []byte { key := []byte(keysting) var result [16]byte // If key length is more than 32, truncate it if len(key) > 16 { key = key[:16] } // If key length is less than 32, replicate it until it reaches 32 for len(key) < 16 { key = append(key, key...) } // Only take the first 32 bytes key = key[:16] // Swap the first 16 bytes with the last 16 bytes copy(result[:], key[8:]) copy(result[8:], key[:8]) return result[:] } type SecretGCM[T any] struct { srcKey string secretKey []byte derivationKey func(keysting string) []byte mu sync.Mutex EncDec ISecretEncDec } func NewSecretGCM[T any](key string) *SecretGCM[T] { s := &SecretGCM[T]{ srcKey: key, derivationKey: DerivationKeyV1, EncDec: base64.RawURLEncoding, } s.secretKey = s.derivationKey(s.srcKey) return s } func (s *SecretGCM[T]) UpdateDerivationKeyFunc(kfunc func(keysting string) []byte) { s.mu.Lock() defer s.mu.Unlock() s.derivationKey = kfunc s.secretKey = s.derivationKey(s.srcKey) } func (s *SecretGCM[T]) Encrypt(gobj *T) (string, error) { s.mu.Lock() defer s.mu.Unlock() var buf = bytes.NewBuffer(nil) err := gob.NewEncoder(buf).Encode(gobj) if err != nil { return "", err } block, err := aes.NewCipher(s.secretKey) if err != nil { return "", err } nonce := make([]byte, 12) if _, err := io.ReadFull(rand.Reader, nonce); err != nil { return "", err } aesgcm, err := cipher.NewGCM(block) if err != nil { return "", err } ciphertext := aesgcm.Seal(nonce, nonce, buf.Bytes(), nil) return s.EncDec.EncodeToString(ciphertext), nil } func (s *SecretGCM[T]) Decrypt(ciphertext string) (*T, error) { block, err := aes.NewCipher(s.secretKey) if err != nil { return nil, err } ct, err := s.EncDec.DecodeString(ciphertext) if err != nil { return nil, err } if len(ct) < 12 { return nil, fmt.Errorf("ciphertext too short") } aesgcm, err := cipher.NewGCM(block) if err != nil { return nil, err } plaintext, err := aesgcm.Open(nil, ct[:12], ct[12:], nil) if err != nil { return nil, err } // 解出golang的结构体 var protected T var buf = bytes.NewBuffer(plaintext) err = gob.NewDecoder(buf).Decode(&protected) if err != nil { return nil, err } return &protected, nil }