diff --git a/go.mod b/go.mod index 60d41c0f..002880ba 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,10 @@ module fusenapi go 1.20 -require github.com/zeromicro/go-zero v1.5.2 +require ( + github.com/golang-jwt/jwt v3.2.2+incompatible + github.com/zeromicro/go-zero v1.5.2 +) require ( github.com/beorn7/perks v1.0.1 // indirect diff --git a/go.sum b/go.sum index 1ba6a256..5bbedf7f 100644 --- a/go.sum +++ b/go.sum @@ -88,6 +88,8 @@ github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= +github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= diff --git a/product/internal/handler/getproductlisthandler.go b/product/internal/handler/getproductlisthandler.go index efbe0594..22a3d962 100644 --- a/product/internal/handler/getproductlisthandler.go +++ b/product/internal/handler/getproductlisthandler.go @@ -1,6 +1,7 @@ package handler import ( + "fusenapi/utils/auth" "net/http" "fusenapi/product/internal/logic" @@ -11,14 +12,19 @@ import ( func GetProductListHandler(svcCtx *svc.ServiceContext) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { + //检测登录权限 + userInfo, err := auth.CheckAuth(r) + if err != nil { + httpx.ErrorCtx(r.Context(), w, err) + return + } var req types.GetProductListReq if err := httpx.Parse(r, &req); err != nil { httpx.ErrorCtx(r.Context(), w, err) return } - l := logic.NewGetProductListLogic(r.Context(), svcCtx) - resp, err := l.GetProductList(&req) + resp, err := l.GetProductList(&req, userInfo.UserId) if err != nil { httpx.ErrorCtx(r.Context(), w, err) } else { diff --git a/product/internal/logic/getproductlistlogic.go b/product/internal/logic/getproductlistlogic.go index b15c9e34..be15ff1c 100644 --- a/product/internal/logic/getproductlistlogic.go +++ b/product/internal/logic/getproductlistlogic.go @@ -2,6 +2,8 @@ package logic import ( "context" + "errors" + "fusenapi/model" "fusenapi/utils/image" "fusenapi/product/internal/svc" @@ -25,11 +27,19 @@ func NewGetProductListLogic(ctx context.Context, svcCtx *svc.ServiceContext) *Ge } // 获取产品列表 -func (l *GetProductListLogic) GetProductList(req *types.GetProductListReq) (resp *types.Response, err error) { +func (l *GetProductListLogic) GetProductList(req *types.GetProductListReq, uid int64) (resp *types.Response, err error) { //获取合适尺寸 if req.Size > 0 { req.Size = image.GetCurrentSize(req.Size) } //获取是否存在千人千面 + userModel := model.NewFsUserModel(l.svcCtx.MysqlConn) + userInfo, err := userModel.FindOne(l.ctx, uid) + if err != nil { + return nil, err + } + if userInfo.Id == 0 { + return nil, errors.New("user not exists") + } return } diff --git a/product/internal/svc/servicecontext.go b/product/internal/svc/servicecontext.go index d1c7fc03..0f1c06b7 100644 --- a/product/internal/svc/servicecontext.go +++ b/product/internal/svc/servicecontext.go @@ -1,19 +1,18 @@ package svc import ( - "fusenapi/model" "fusenapi/product/internal/config" "github.com/zeromicro/go-zero/core/stores/sqlx" ) type ServiceContext struct { - Config config.Config - FsProductModel model.FsProductModel + Config config.Config + MysqlConn sqlx.SqlConn } func NewServiceContext(c config.Config) *ServiceContext { return &ServiceContext{ - Config: c, - FsProductModel: model.NewFsProductModel(sqlx.NewMysql(c.DataSource)), + Config: c, + MysqlConn: sqlx.NewMysql(c.DataSource), } } diff --git a/product/product.go b/product/product.go index e5afc4c9..99047491 100644 --- a/product/product.go +++ b/product/product.go @@ -22,7 +22,6 @@ func main() { server := rest.MustNewServer(c.RestConf) defer server.Stop() - ctx := svc.NewServiceContext(c) handler.RegisterHandlers(server, ctx) diff --git a/utils/auth/auth.go b/utils/auth/auth.go new file mode 100644 index 00000000..df1e2487 --- /dev/null +++ b/utils/auth/auth.go @@ -0,0 +1,64 @@ +package auth + +import ( + "encoding/json" + "errors" + "github.com/golang-jwt/jwt" + "net/http" + "time" +) + +type UserInfo struct { + UserId int64 `json:"user_id"` +} + +// 签名key +var signKey = "FushenFGbhgfhgKgGH556HGlXrsfJKhhjYFGKLO==" +var expireTime = int64(3600) + +// 生成token +func GenJwtToken(userInfo UserInfo) (token string, err error) { + t := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "user_id": userInfo.UserId, + "exp": time.Now().Add(time.Second * time.Duration(expireTime)).Unix(), //过期时间 + "iss": "fusen", + }) + token, err = t.SignedString([]byte(signKey)) + if err != nil { + return "", err + } + return +} + +// 解释token +func ParseJwtToken(token string) (UserInfo, error) { + t, err := jwt.ParseWithClaims(token, jwt.MapClaims{}, func(token *jwt.Token) (interface{}, error) { + return []byte(signKey), nil + }) + if err != nil { + return UserInfo{}, err + } + d, err := json.Marshal(t.Claims) + if err != nil { + return UserInfo{}, err + } + var userInfo UserInfo + if err = json.Unmarshal(d, &userInfo); err != nil { + return UserInfo{}, err + } + return userInfo, nil +} + +// 检测授权 +func CheckAuth(r *http.Request) (UserInfo, error) { + token := r.Header.Get("Authorization") + if token == "" { + return UserInfo{}, errors.New("token is required") + } + //解析token + userInfo, err := ParseJwtToken(token) + if err != nil { + return UserInfo{}, err + } + return userInfo, nil +}