proto/goutils/proto_build/main.go
2023-11-30 11:05:42 +08:00

1437 lines
37 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"bufio"
"bytes"
"fmt"
"go/ast"
"go/format"
"go/parser"
"go/printer"
"go/token"
"io"
"io/fs"
"log"
"os"
"os/exec"
"path"
"path/filepath"
"regexp"
"runtime"
"strings"
"text/template"
"unicode"
"golang.org/x/text/cases"
"golang.org/x/text/language"
"gopkg.in/ini.v1"
)
var (
	// workspaceDir is the directory that contains the "proto" folder. It is set by
	// init, which also makes it the process working directory.
	workspaceDir string
	// tpl holds the *.tpl code-generation templates parsed by init.
	tpl *template.Template
)
// init prepares the generator before main runs. It locates the workspace (the nearest
// ancestor of the current directory holding a "proto" entry), switches the working
// directory to it, and parses every template under
// proto/goutils/proto_build/tpls. Any failure aborts the program with a panic.
func init() {
	log.SetFlags(log.Llongfile)

	var err error
	if workspaceDir, err = findProtoParentFolder(); err != nil {
		log.Panic(err)
	}
	if err = os.Chdir(workspaceDir); err != nil {
		log.Panicf("切换目录失败:%v\n", err)
	}

	templateGlob := workspaceDir + "/proto/goutils/proto_build/tpls/*.tpl"
	if tpl, err = template.ParseGlob(templateGlob); err != nil {
		panic(err)
	}
}
// main dispatches on the first command-line argument: "service" (also the default
// when no argument is given) runs ServiceMain, "gateway" runs GatewayMain, and
// anything else prints usage.
func main() {
	mode := "service"
	if len(os.Args) > 1 {
		mode = os.Args[1]
	}
	switch mode {
	case "service":
		ServiceMain()
	case "gateway":
		GatewayMain()
	default:
		log.Println("Invalid argument. Usage: go run main.go [service|gateway]")
	}
}
// ServiceMain runs the "service" generation pipeline from the workspace root. It:
//   - checks proto message names;
//   - loads server/service_config.ini;
//   - runs protoc over proto/service into gen/go/service;
//   - generates test scaffolding, config files, the nacos client wrapper and the
//     logic stubs.
// protoc.sh is created (only if missing) when the function returns.
func ServiceMain() {
	const (
		packageName    = "service"
		serverCodePath = "server"
		genDir         = "gen/go"
	)
	log.Println("项目目录:", workspaceDir)
	checkProtoMessageName(workspaceDir + "/" + "proto/service")

	serviceName, projectName, projectLastName := getServiceNameAndProjectName(serverCodePath)
	if serviceName == "" {
		// Leave a config template behind for the user to fill in, then stop.
		if err := os.MkdirAll("server", 0755); err != nil {
			log.Println(err)
		}
		createFileWithPermNotExists("server/service_config.ini", func(w io.Writer) error {
			return tpl.ExecuteTemplate(w, "service_config.tpl", nil)
		})
		panic("必须填写server/service_config.ini文件的项目名称与proto对应名称")
	}
	// Deferred: runs when ServiceMain returns, including while a later step panics.
	defer createFileNotExists("protoc.sh", 0755, func(w io.Writer) error {
		script := `#! /bin/bash
go run -gcflags="-N" proto/goutils/proto_build/main.go`
		w.Write([]byte(script))
		return nil
	})

	if err := os.MkdirAll(genDir+"/"+packageName, 0755); err != nil {
		log.Println(err)
	}
	CheckGomodBasicPackage()
	ExecProtoc(workspaceDir, "proto/"+packageName, genDir, packageName, projectName)
	ExecCreateTest(genDir, projectLastName, packageName, serviceName)
	ExecCreateConfig(serviceName, projectLastName)
	ExecCreateAutoGrpc(genDir, packageName)
	ExecCreateAutoLogic(workspaceDir, serviceName, genDir, packageName, projectLastName)
}
// GatewayMain runs the "gateway" generation pipeline from the workspace root. It:
//   - checks proto message names;
//   - loads server/service_config.ini;
//   - runs protoc over proto/service into gen/go/service;
//   - generates the nacos client wrapper, config scaffolding and the nacos-aware
//     grpc-gateway registration code.
// protoc.sh (invoking the gateway mode) is created, only if missing, when the
// function returns.
func GatewayMain() {
	const (
		packageName    = "service"
		serverCodePath = "server"
		genDir         = "gen/go"
	)
	log.Println("项目目录:", workspaceDir)
	checkProtoMessageName(workspaceDir + "/" + "proto/service")

	serviceName, projectName, projectLastName := getServiceNameAndProjectName(serverCodePath)
	if serviceName == "" {
		// Leave a config template behind for the user to fill in, then stop.
		if err := os.MkdirAll("server", 0755); err != nil {
			log.Println(err)
		}
		createFileWithPermNotExists("server/service_config.ini", func(w io.Writer) error {
			return tpl.ExecuteTemplate(w, "service_config.tpl", nil)
		})
		panic("必须填写server/service_config.ini文件的项目名称与proto对应名称")
	}
	// Deferred: runs when GatewayMain returns, including while a later step panics.
	defer createFileNotExists("protoc.sh", 0755, func(w io.Writer) error {
		script := `#! /bin/bash
go run -gcflags="-N" proto/goutils/proto_build/main.go gateway`
		w.Write([]byte(script))
		return nil
	})

	if err := os.MkdirAll(genDir+"/"+packageName, 0755); err != nil {
		log.Println(err)
	}
	ExecProtoc(workspaceDir, "proto/"+packageName, genDir, packageName, projectName)
	ExecCreateAutoGrpc(genDir, packageName)
	ExecCreateConfig(serviceName, projectLastName)
	ExecCreateGatewayAutoGrpc(genDir, packageName, projectLastName)
}
// ExecCreateConfig creates the static scaffolding of a service project from templates:
// server/config/config.go, server/main_test.go, .gitignore and the executable
// update_fspkg_master.sh. Files that already exist are left untouched.
// ServiceName and ProjectName are currently unused.
//
// It panics if server/config cannot be created. A failure to create an individual
// file is logged and does not prevent the remaining files from being created.
func ExecCreateConfig(ServiceName, ProjectName string) {
	if err := os.MkdirAll("server/config", 0755); err != nil {
		panic(err)
	}

	scaffolds := []struct{ path, tplName string }{
		{"server/config/config.go", "config.tpl"},
		{"server/main_test.go", "main_test.tpl"},
		{".gitignore", "gitignore.tpl"},
	}
	for _, s := range scaffolds {
		tplName := s.tplName
		err := createFileWithPermNotExists(s.path, func(f io.Writer) error {
			return tpl.ExecuteTemplate(f, tplName, nil)
		})
		if err != nil {
			log.Printf("创建文件 %s 失败:%v\n", s.path, err)
		}
	}

	// The helper script must be executable, hence the explicit permission bits.
	const script = "update_fspkg_master.sh"
	err := createFileNotExists(script, 0755, func(f io.Writer) error {
		return tpl.ExecuteTemplate(f, "update_fspkg_master.tpl", nil)
	})
	if err != nil {
		log.Printf("创建文件 %s 失败:%v\n", script, err)
	}
}
// CheckGomodBasicPackage makes sure ./go.mod references the fusen base modules.
// For each of fusen-basic and fusen-model whose name (followed by a space) does not
// occur anywhere in go.mod, a `replace ... => gitlab.fusenpack.com/backend/...`
// directive is appended. The file is rewritten only if something was added.
// It panics if go.mod cannot be read or written.
func CheckGomodBasicPackage() {
	data, err := os.ReadFile("./go.mod")
	if err != nil {
		panic(err)
	}
	content := string(data)
	rewrite := false
	// Plain substring checks: the module names contain no pattern syntax, so there
	// is no need to compile a regular expression.
	if !strings.Contains(content, "fusen-basic ") {
		rewrite = true
		content += "\n\nreplace fusen-basic v0.0.0 => gitlab.fusenpack.com/backend/basic v0.0.1"
	}
	if !strings.Contains(content, "fusen-model ") {
		rewrite = true
		content += "\n\nreplace fusen-model v0.0.0 => gitlab.fusenpack.com/backend/model v0.0.1"
	}
	if !rewrite {
		return
	}
	log.Println("rewrite go.mod => add <fusen-basic> <fusen-model>")
	if err := os.WriteFile("./go.mod", []byte(content), 0644); err != nil {
		panic(err)
	}
}
// ExecProtoc builds and runs one protoc command (via `sh -c`, with workerSpaceDir as
// the working directory).
//
// The command uses proto as the include root and writes Go, gRPC and grpc-gateway
// output to genDir with source-relative paths. Its inputs are every service proto
// returned by getAllServiceName (as <packageName>/<name>.proto) plus every
// service/*.proto those files import. Each input is mapped to the Go import path
// projectName through M options.
//
// A failing protoc run is logged, not fatal.
func ExecProtoc(workerSpaceDir, serviceProtoDir, genDir, packageName string, projectName string) {
	serviceNames := getAllServiceName()

	var cmdLine strings.Builder
	fmt.Fprintf(&cmdLine, `protoc -I %s --go_out %s --go_opt paths=source_relative --go-grpc_out %s --go-grpc_opt paths=source_relative --grpc-gateway_out %s --grpc-gateway_opt paths=source_relative`, "proto", genDir, genDir, genDir)
	for _, name := range serviceNames {
		cmdLine.WriteString(importFileCmdStr(serviceNameEncode(packageName, name), projectName))
	}
	for _, imported := range checkServiceProtoImports(serviceProtoDir, serviceNames...) {
		cmdLine.WriteString(importFileCmdStr(imported, projectName))
	}

	command := cmdLine.String()
	log.Println(command)
	cmd := exec.Command("sh", "-c", command)
	cmd.Dir = workerSpaceDir
	cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr
	if err := cmd.Run(); err != nil {
		log.Printf("命令执行失败:%v\n", err)
	}
}
// ExecCreateAutoLogic generates the logic layer for ServiceName.
//
// It processes every file in <genDir>/<packageName> whose name starts with
// ServiceName and ends in "grpc.pb.go". For each XxxServer interface that
// ParseGrpcServerInfo reports, it creates the directory
// server/logics/<snake_case(XxxLogic)> containing:
//   - types_gen.go, rendered from logic_grpc_struct.tpl and rewritten on every run;
//   - logic_init.go and one <snake_case(Method)>_logic.go per method, created only
//     when missing.
// When the function returns, server/main.go is created from main.tpl (only if
// missing) with the logic packages collected here.
//
// Only methods shaped (ctx, req) (*Resp, error) fit the templates. Other shapes,
// such as streaming RPCs, are skipped with a log line instead of panicking on a
// missing parameter or producing a bogus response type.
// workerSpaceDir is currently unused.
func ExecCreateAutoLogic(workerSpaceDir string, ServiceName string, genDir, packageName, projectName string) {
	grpcFilePath := genDir + "/" + packageName
	files, err := os.ReadDir(grpcFilePath)
	if err != nil {
		log.Println("Error reading directory:", err)
		return
	}
	type StructService struct {
		StructServiceName string
		LogicPackageName  string
	}
	type MainTpl struct {
		ProjectName        string
		StructServiceNames []*StructService
		LogicDirNames      []string
	}
	mtpl := &MainTpl{ProjectName: projectName}
	// Deferred so that main.go is rendered with every logic package collected below.
	defer func() {
		createFileWithPermNotExists("server/main.go", func(f io.Writer) error {
			return tpl.ExecuteTemplate(f, "main.tpl", mtpl)
		})
	}()

	for _, file := range files {
		name := file.Name()
		if file.IsDir() || !strings.HasSuffix(name, "grpc.pb.go") || !strings.HasPrefix(name, ServiceName) {
			continue
		}
		// One closure per file: a formatting failure abandons the rest of this file only.
		func() {
			infos := ParseGrpcServerInfo(filepath.Join(grpcFilePath, name))
			for _, info := range infos {
				var methods []map[string]string
				for _, met := range info.Method {
					if len(met.Params) < 2 || len(met.Returns) < 1 || !strings.HasPrefix(met.Returns[0], "*") {
						log.Printf("skip method %s.%s: only (ctx, req) (*Resp, error) methods are supported\n", info.ServiceName, met.MethodName)
						continue
					}
					methods = append(methods, map[string]string{
						"StructName":     info.StructName,
						"MethodName":     met.MethodName,
						"ParamCtx":       met.Params[0],
						"ParamReq":       met.Params[1],
						"MethodReturn":   met.Returns[0],
						"MethodResponse": strings.TrimPrefix(met.Returns[0], "*"),
					})
				}

				logicPackageName := convertToSnakeCase(info.StructName)
				ss := &StructService{LogicPackageName: logicPackageName}

				var genTypesBuffer bytes.Buffer
				if err := tpl.ExecuteTemplate(&genTypesBuffer, "logic_grpc_struct.tpl", map[string]any{
					"ProjectName":             projectName,
					"StructName":              info.StructName,
					"UnimplementedStructName": info.UnimplementedStructName,
					"Methods":                 methods,
					"PackageName":             logicPackageName,
				}); err != nil {
					panic(err)
				}

				logicPath := "server/logics/" + logicPackageName
				mtpl.LogicDirNames = append(mtpl.LogicDirNames, logicPath)
				if err := os.MkdirAll(logicPath, 0755); err != nil {
					panic(err)
				}

				// Format before touching types_gen.go so that a formatting error never leaves
				// a truncated file behind; os.WriteFile also closes the file.
				formatted, err := format.Source(genTypesBuffer.Bytes())
				if err != nil {
					log.Printf("格式化代码失败:%v\n", err)
					return
				}
				if err := os.WriteFile(logicPath+"/types_gen.go", formatted, 0644); err != nil {
					panic(err)
				}

				createFileWithPermNotExists(logicPath+"/logic_init.go", func(f io.Writer) error {
					return tpl.ExecuteTemplate(f, "logic_init.tpl", map[string]any{
						"StructNames": []string{info.StructName},
						"ProjectName": projectName,
						"PackageName": logicPackageName,
					})
				})
				ss.StructServiceName = info.ServiceName

				for _, method := range methods {
					method["ProjectName"] = projectName
					method["PackageName"] = logicPackageName
					handlerPath := fmt.Sprintf("%s/%s_logic.go", logicPath, convertToSnakeCase(method["MethodName"]))
					createFileWithPermNotExists(handlerPath, func(f io.Writer) error {
						return tpl.ExecuteTemplate(f, "logic_fusen_handler.tpl", method)
					})
				}
				mtpl.StructServiceNames = append(mtpl.StructServiceNames, ss)
			}
		}()
	}
}
// ExecCreateTest generates HTTP-level test scaffolding for serviceName under
// server/test.
//   - test_method_gen.go is regenerated on every run from the routes found in every
//     <lower(serviceName)>.pb.gw.go file below <genDir>/<packageName>.
//   - One <snake_case(Method)>_test.go per route, var.go and run_latest.sh are
//     created only when missing.
// If the generated code cannot be formatted or written, the error is logged and the
// remaining steps are skipped.
func ExecCreateTest(genDir, projectName, packageName, serviceName string) {
	gwDir := genDir + "/" + packageName
	var routes []*HttpGrpcMethodTest
	for _, gwFile := range getSuffixFilesPath(gwDir, strings.ToLower(serviceName)+".pb.gw.go") {
		routes = append(routes, genGatewayTestFunction(gwFile)...)
	}

	var rendered bytes.Buffer
	if err := tpl.ExecuteTemplate(&rendered, "http_grpc_method_test.tpl", map[string]any{
		"ProjectName":         projectName,
		"HttpGrpcTestStructs": routes,
	}); err != nil {
		panic(err)
	}
	formatted, err := format.Source(rendered.Bytes())
	if err != nil {
		log.Printf("格式化代码失败:%v\n", err)
		return
	}

	const testDir = "server/test"
	if err := os.MkdirAll(testDir, 0755); err != nil {
		log.Println(err)
	}
	if err := os.WriteFile(fmt.Sprintf("%s/test_method_gen.go", testDir), formatted, 0644); err != nil {
		log.Printf("无法写入文件:%v\n", err)
		return
	}

	for _, route := range routes {
		r := route
		createFileWithPermNotExists(testDir+"/"+convertToSnakeCase(r.MethodName)+"_test.go", func(f io.Writer) error {
			return tpl.ExecuteTemplate(f, "http_grpc_method_file_test.tpl", map[string]any{
				"ProjectName":   projectName,
				"RequestVar":    r.RequestVar,
				"MethodName":    r.MethodName,
				"RequestStruct": r.RequestStruct,
			})
		})
	}
	createFileWithPermNotExists(testDir+"/"+"var.go", func(f io.Writer) error {
		return tpl.ExecuteTemplate(f, "http_grpc_method_var.tpl", map[string]any{
			"ProjectName": projectName,
		})
	})
	createFileNotExists("run_latest.sh", 0755, func(f io.Writer) error {
		return tpl.ExecuteTemplate(f, "run_latest.tpl", nil)
	})
}
// ClientParam describes one gRPC service discovered in a generated *grpc.pb.go file.
//
// parseGoFile fills:
//   - ClientName with Xxx from the XxxClient interface;
//   - ServerName with Xxx from the UnimplementedXxxServer struct;
//   - ServiceName and GrpcServiceName with the two dot-separated halves of the
//     ServiceDesc ServiceName string.
// ExecCreateAutoGrpc re-maps these values before passing them to
// auto_grpc_nacos.tpl.
type ClientParam struct {
	ClientName, ServerName       string
	ServiceName, GrpcServiceName string
}
// 执行auto grpc
func ExecCreateAutoGrpc(genDir, packageName string) {
genDir = genDir + "/" + packageName
var clientParams []ClientParam
for _, params := range getGrpcFileClientNames(genDir) {
clientParams = append(clientParams, ClientParam{
ClientName: params.ClientName,
ServerName: params.ServerName,
ServiceName: cases.Lower(language.English).String(params.ClientName),
GrpcServiceName: params.ServiceName,
})
}
var buf bytes.Buffer
tpl.ExecuteTemplate(&buf, "auto_grpc_nacos.tpl", map[string]any{
"PackageName": packageName,
"ClientParams": clientParams,
})
formatted, err := format.Source([]byte(buf.Bytes()))
if err != nil {
log.Printf("格式化代码失败:%v\n", err)
return
}
filePath := fmt.Sprintf("%s/auto_%s_grpc_nacos_client.pb.go", genDir, packageName)
err = os.WriteFile(filePath, formatted, 0644)
if err != nil {
log.Printf("无法写入文件:%v\n", err)
return
}
// log.Println(string(formatted))
}
// ExecCreateGatewayAutoGrpc generates the gateway side of the nacos integration in
// two steps.
//  1. It renders auto_grpc_gateway_nacos.tpl, appends a nacos-aware copy of every
//     RegisterXxxHandlerClient found in the *pb.gw.go files of
//     <genDir>/<packageName>, and writes the result to
//     auto_<packageName>_grpc_nacos_gateway.pb.go.
//  2. It renders server/logic/gateway_logic_gen.go, which lists the generated
//     function names.
// Template failures panic. Formatting and write failures are logged and stop the
// generation; the unformatted gateway source is logged to help debugging.
func ExecCreateGatewayAutoGrpc(genDir, packageName, ProjectName string) {
	dir := genDir + "/" + packageName

	var gateway bytes.Buffer
	if err := tpl.ExecuteTemplate(&gateway, "auto_grpc_gateway_nacos.tpl", nil); err != nil {
		panic(err)
	}
	var registerFuncs []string
	for _, gwFile := range getSuffixFilesPath(dir, "pb.gw.go") {
		registerFuncs = append(registerFuncs, genGatewayWithNacosFunction(gwFile, &gateway)...)
	}
	code, err := format.Source(gateway.Bytes())
	if err != nil {
		log.Printf("格式化代码失败:%v\n", err)
		log.Println(gateway.String())
		return
	}
	if err = os.WriteFile(fmt.Sprintf("%s/auto_%s_grpc_nacos_gateway.pb.go", dir, packageName), code, 0644); err != nil {
		log.Printf("无法写入文件:%v\n", err)
		return
	}

	if err = os.MkdirAll("server/logic", 0755); err != nil {
		log.Println(err)
	}
	var register bytes.Buffer
	if err = tpl.ExecuteTemplate(&register, "auto_grpc_gateway_register.tpl", map[string]any{
		"ProjectName": ProjectName,
		"FuncNames":   registerFuncs,
	}); err != nil {
		panic(err)
	}
	if code, err = format.Source(register.Bytes()); err != nil {
		log.Printf("格式化代码失败:%v\n", err)
		return
	}
	if err = os.WriteFile("server/logic/gateway_logic_gen.go", code, 0644); err != nil {
		log.Printf("无法写入文件:%v\n", err)
	}
}
// genGatewayWithNacosFunction rewrites the grpc-gateway client registration functions
// of the file at grpcPath so that they obtain their gRPC client through nacos.
//
// For every exported RegisterXxxHandlerClient it appends to gatewayBuf a modified
// copy named RegisterXxxHandlerClientNacos, in which:
//   - the trailing `client XxxClient` parameter becomes `opts ...grpc.DialOption`;
//   - in every mux.Handle(..., ..., callback), the statement calling request_* is
//     preceded by `grpcClient, err := AutoXxxClientEx(ctx, opts...)`;
//   - that new statement is followed by a copy of the nearest preceding `if`
//     statement (the generated `if err != nil { ...; return }` check);
//   - the `client` argument of request_* is replaced by grpcClient.
//
// It returns the names of the generated functions. The process exits if the file
// cannot be parsed. It panics if a callback has no `if` statement before its
// request_* call.
func genGatewayWithNacosFunction(grpcPath string, gatewayBuf *bytes.Buffer) (funcNames []string) {
	fset := token.NewFileSet()
	node, err := parser.ParseFile(fset, grpcPath, nil, parser.AllErrors)
	if err != nil {
		log.Println("解析文件失败:", err)
		os.Exit(1)
	}
	const prefix, suffix = "Register", "HandlerClient"

	// injectClient rewrites one mux.Handle callback body of service cName in place.
	injectClient := func(body *ast.BlockStmt, cName, funcName string) {
		for i, stmt := range body.List {
			assign, ok := stmt.(*ast.AssignStmt)
			if !ok || len(assign.Rhs) == 0 {
				continue
			}
			call, ok := assign.Rhs[0].(*ast.CallExpr)
			if !ok {
				continue
			}
			callee, ok := call.Fun.(*ast.Ident)
			if !ok || !strings.HasPrefix(callee.Name, "request_") {
				continue
			}
			for j, arg := range call.Args {
				if ident, ok := arg.(*ast.Ident); ok && ident.Name == "client" {
					call.Args[j] = ast.NewIdent("grpcClient")
					break
				}
			}
			// Reuse the generated error check that precedes the request call. Its
			// position differs between grpc-gateway versions, so search for it instead
			// of assuming a fixed index.
			errCheck := -1
			for j := i - 1; j >= 0; j-- {
				if _, isIf := body.List[j].(*ast.IfStmt); isIf {
					errCheck = j
					break
				}
			}
			if errCheck < 0 {
				panic(fmt.Sprintf("%s: %s: no error check before %s", grpcPath, funcName, callee.Name))
			}
			opts := ast.NewIdent("opts")
			newClient := &ast.AssignStmt{
				Tok: token.DEFINE,
				Lhs: []ast.Expr{ast.NewIdent("grpcClient"), ast.NewIdent("err")},
				Rhs: []ast.Expr{&ast.CallExpr{
					Fun:      ast.NewIdent("Auto" + cName + "ClientEx"),
					Args:     []ast.Expr{ast.NewIdent("ctx"), opts},
					Ellipsis: opts.End(),
				}},
			}
			list := make([]ast.Stmt, 0, len(body.List)+2)
			list = append(list, body.List[:i]...)
			list = append(list, newClient, body.List[errCheck])
			body.List = append(list, body.List[i:]...)
			return
		}
	}

	for _, decl := range node.Decls {
		fdec, ok := decl.(*ast.FuncDecl)
		if !ok || !fdec.Name.IsExported() || fdec.Body == nil {
			continue
		}
		name := fdec.Name.Name
		if !strings.HasPrefix(name, prefix) || !strings.HasSuffix(name, suffix) || len(name) < len(prefix)+len(suffix) {
			continue
		}
		if fdec.Type.Params == nil || len(fdec.Type.Params.List) == 0 {
			continue
		}
		cName := name[len(prefix) : len(name)-len(suffix)]
		fdec.Name.Name = name + "Nacos"
		funcNames = append(funcNames, fdec.Name.Name)

		// Swap the trailing `client XxxClient` parameter for `opts ...grpc.DialOption`.
		params := fdec.Type.Params.List
		last := params[len(params)-1]
		last.Names = []*ast.Ident{{Name: "opts"}}
		last.Type = &ast.Ellipsis{
			Elt: &ast.SelectorExpr{
				X:   &ast.Ident{Name: "grpc"},
				Sel: &ast.Ident{Name: "DialOption"},
			},
		}

		// Each route is registered as mux.Handle(method, pattern, callback).
		for _, stmt := range fdec.Body.List {
			exprStmt, ok := stmt.(*ast.ExprStmt)
			if !ok {
				continue
			}
			handle, ok := exprStmt.X.(*ast.CallExpr)
			if !ok || len(handle.Args) != 3 {
				continue
			}
			if callback, ok := handle.Args[2].(*ast.FuncLit); ok {
				injectClient(callback.Body, cName, fdec.Name.Name)
			}
		}

		if err := gatewayBuf.WriteByte('\n'); err != nil {
			panic(err)
		}
		if err := printer.Fprint(gatewayBuf, fset, fdec); err != nil {
			panic(err)
		}
	}
	return funcNames
}
// HttpGrpcMethodTest holds what genGatewayTestFunction extracts for one HTTP route
// of a grpc-gateway handler. It feeds the http_grpc_method_*.tpl templates.
type HttpGrpcMethodTest struct {
	// Request variable name, request type and RPC method name.
	RequestVar, RequestStruct, MethodName string
	// HTTP method and URL pattern, copied as literal source text (quotes included).
	HttpMethod, UrlPath string
	ServiceName         string
}
// genGatewayTestFunction parses the grpc-gateway file at grpcPath. It returns one
// HttpGrpcMethodTest per route registered by the file's exported *HandlerServer
// functions.
//
// For each such function it walks the body with ast.Inspect and fills a shared
// `created` value from two kinds of call:
//   - a `<x>.Handle(...)` call:
//       - HttpMethod is the first argument's literal source text, quotes included;
//       - UrlPath is the literal passed to WithHTTPPathPattern anywhere inside that
//         call.
//   - a call to a `local_request*` function:
//       - ServiceName is the third "_"-separated part of the function's name;
//       - inside that function's declaration, RequestStruct is the type of the
//         `protoReq` variable;
//       - MethodName is the selector called on `server`.
//     Once MethodName is found, the entry is appended, RequestVar is set to
//     "var"+MethodName+"Req", and a fresh value is started.
// This relies on ast.Inspect's pre-order traversal visiting the Handle call before
// the local_request* call nested inside its callback.
//
// The process exits if the file cannot be parsed.
//
// NOTE(review): several type assertions are unchecked and panic on unexpected
// generator output:
//   - callExpr.Args[0].(*ast.BasicLit) assumes the HTTP method is a literal;
//   - expr.Obj.Decl assumes local_request* is declared in, and resolved by the
//     parser within, this same file.
func genGatewayTestFunction(grpcPath string) (createdCollection []*HttpGrpcMethodTest) {
// Parse the Go source file.
fset := token.NewFileSet()
node, err := parser.ParseFile(fset, grpcPath, nil, parser.AllErrors)
if err != nil {
log.Println("解析文件失败:", err)
os.Exit(1)
}
for _, decl := range node.Decls {
// Only exported functions named *HandlerServer register the server-side routes.
if fdec, ok := decl.(*ast.FuncDecl); ok && fdec.Name.IsExported() && strings.HasSuffix(fdec.Name.Name, "HandlerServer") {
created := &HttpGrpcMethodTest{}
// Collect route data from the function body (the AST is only read, not modified).
ast.Inspect(fdec.Body, func(n ast.Node) bool {
if callExpr, ok := n.(*ast.CallExpr); ok {
if expr, ok := callExpr.Fun.(*ast.Ident); ok && strings.HasPrefix(expr.Name, "local_request") {
// e.g. local_request_Auth_Login_0 -> "Auth"
created.ServiceName = strings.Split(expr.Name, "_")[2]
// Look inside the local_request* declaration for the request type and the server method.
ast.Inspect(expr.Obj.Decl.(*ast.FuncDecl).Body, func(n ast.Node) bool {
if protoValue, ok := n.(*ast.ValueSpec); ok && protoValue.Names[0].Name == "protoReq" {
created.RequestStruct = getTypeString(protoValue.Type)
}
if expr, ok := n.(*ast.SelectorExpr); ok {
if exprServer, ok := expr.X.(*ast.Ident); ok && exprServer.Name == "server" {
// server.<Method> completes this route: record it and start a fresh entry.
created.MethodName = expr.Sel.Name
createdCollection = append(createdCollection, created)
created.RequestVar = "var" + created.MethodName + "Req"
created = &HttpGrpcMethodTest{}
return false
}
}
return true
})
return true
}
// mux.Handle("<METHOD>", pattern, callback): record the HTTP method and URL pattern.
if expr, ok := callExpr.Fun.(*ast.SelectorExpr); ok && expr.Sel.Name == "Handle" {
var x = callExpr.Args[0].(*ast.BasicLit)
created.HttpMethod = x.Value
ast.Inspect(callExpr, func(n ast.Node) bool {
if callExprHandler, ok := n.(*ast.CallExpr); ok {
if expr, ok := callExprHandler.Fun.(*ast.SelectorExpr); ok && expr.Sel.Name == "WithHTTPPathPattern" {
var x = callExprHandler.Args[0].(*ast.BasicLit)
created.UrlPath = x.Value
}
}
return true
})
return true
}
}
return true
})
}
}
return
}
// getSuffixFilesPath walks dir recursively, in lexical order, and returns the paths
// of all non-directory entries whose name ends with suffix. It panics if the walk
// fails.
func getSuffixFilesPath(dir string, suffix string) (result []string) {
	collect := func(p string, info os.FileInfo, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}
		if !info.IsDir() && strings.HasSuffix(info.Name(), suffix) {
			result = append(result, p)
		}
		return nil
	}
	if err := filepath.Walk(dir, collect); err != nil {
		panic(err)
	}
	return result
}
// getGrpcFileClientNames walks genDir recursively. It returns the ClientParam
// entries that parseGoFile extracts from every file whose name ends in "grpc.pb.go".
// If the walk fails, the error is logged and nil is returned.
func getGrpcFileClientNames(genDir string) (result []*ClientParam) {
	walkErr := filepath.Walk(genDir, func(p string, info os.FileInfo, err error) error {
		switch {
		case err != nil:
			return err
		case info.IsDir(), !strings.HasSuffix(info.Name(), "grpc.pb.go"):
			return nil
		}
		result = append(result, parseGoFile(p)...)
		return nil
	})
	if walkErr != nil {
		log.Printf("遍历目录失败:%v\n", walkErr)
		return nil
	}
	return result
}
// importFileCmdStr returns the protoc arguments for one .proto file. These are the
// file itself plus M-mappings that assign its generated Go, gRPC and grpc-gateway
// code to the Go import path projectName. The result starts and ends with a space,
// so successive results can be concatenated.
func importFileCmdStr(importFile string, projectName string) string {
	mapping := "M" + importFile + "=" + projectName
	return " " + importFile +
		" --go_opt=" + mapping +
		" --go-grpc_opt=" + mapping +
		" --grpc-gateway_opt=" + mapping + " "
}
// serviceNameEncode returns the slash-separated proto path <serviceProtoDir>/<serviceName>.proto.
func serviceNameEncode(serviceProtoDir string, serviceName string) string {
	return serviceProtoDir + "/" + serviceName + ".proto"
}
// checkServiceProtoImports reads <serviceProtoDir>/<name>.proto for every name in
// serviceNames. It returns the distinct paths of their `import "service/....proto"`
// statements.
//
// Lines starting with "//" are ignored. The result keeps first-seen order (files in
// argument order, lines top to bottom), so the protoc command built from it is
// deterministic. The result is nil when nothing is found.
//
// It panics if a proto file cannot be opened. Read errors are logged, and the lines
// read so far are kept.
func checkServiceProtoImports(serviceProtoDir string, serviceNames ...string) []string {
	importPattern := regexp.MustCompile(`import\s+"(service/.*\.proto)"`)
	seen := make(map[string]bool)
	var result []string
	for _, sname := range serviceNames {
		fname := fmt.Sprintf("%s/%s.proto", serviceProtoDir, sname)
		// Scan each file in its own function so its handle is closed before the next
		// file is opened, rather than when the whole loop finishes.
		func() {
			file, err := os.Open(fname)
			if err != nil {
				panic(fmt.Sprintf("无法打开文件: %s\n", err))
			}
			defer file.Close()
			scanner := bufio.NewScanner(file)
			for scanner.Scan() {
				line := strings.TrimSpace(scanner.Text())
				if strings.HasPrefix(line, "//") {
					continue
				}
				for _, m := range importPattern.FindAllStringSubmatch(line, -1) {
					if !seen[m[1]] {
						seen[m[1]] = true
						result = append(result, m[1])
					}
				}
			}
			if err := scanner.Err(); err != nil {
				log.Printf("读取文件失败: %s\n", err)
			}
		}()
	}
	return result
}
// getAllServiceName lists proto/service, relative to the working directory, and
// returns the names of its .proto files without the ".proto" extension, in
// directory order (sorted by file name). Sub-directories and other files are ignored
// rather than being truncated into bogus service names. It panics if the directory
// cannot be read.
func getAllServiceName() (result []string) {
	entries, err := os.ReadDir("proto/service")
	if err != nil {
		panic(err)
	}
	for _, e := range entries {
		name := e.Name()
		if e.IsDir() || !strings.HasSuffix(name, ".proto") {
			continue
		}
		result = append(result, strings.TrimSuffix(name, ".proto"))
	}
	return result
}
// getServiceNameAndProjectName reads the project settings and returns:
//   - serviceName and projectName: SERVICE_NAME and PROJECT_NAME from the DEFAULT
//     section of <dir>/service_config.ini;
//   - goModeName: the module path declared in ./go.mod.
//
// If the ini file cannot be loaded, a template copy is written to that path and the
// function panics, asking the user to fill it in. It also panics when:
//   - either key is missing or empty;
//   - go.mod cannot be read or has no module directive.
func getServiceNameAndProjectName(dir string) (serviceName string, projectName string, goModeName string) {
	const fillConfigMsg = "必须填写server/service_config.ini文件的项目名称与proto对应名称"
	configPath := dir + "/service_config.ini"
	ifile, err := ini.Load(configPath)
	if err != nil {
		if err := os.MkdirAll(dir, 0755); err != nil {
			log.Println(err)
		}
		createFileWithPermNotExists(configPath, func(f io.Writer) error {
			return tpl.ExecuteTemplate(f, "service_config.tpl", nil)
		})
		panic(fillConfigMsg)
	}
	section := ifile.Section("DEFAULT")

	key, err := section.GetKey("SERVICE_NAME")
	if err != nil {
		panic(err)
	}
	if serviceName = key.String(); serviceName == "" {
		panic(fillConfigMsg)
	}

	key, err = section.GetKey("PROJECT_NAME")
	if err != nil {
		panic(err)
	}
	if projectName = key.String(); projectName == "" {
		panic("必须填写server/service_config.ini文件的ProjectName称与仓库对应名称")
	}

	modData, err := os.ReadFile("go.mod")
	if err != nil {
		panic(err)
	}
	// Module paths may contain dots, slashes and digits (e.g. "example.com/team/svc2"),
	// so take everything after the `module` keyword up to the first whitespace.
	m := regexp.MustCompile(`(?m)^\s*module\s+(\S+)`).FindStringSubmatch(string(modData))
	if m == nil {
		panic("无法找到go.mod 获取 module信息")
	}
	goModeName = strings.Trim(m[1], `"`)
	return serviceName, projectName, goModeName
}
// findProtoParentFolder returns the nearest directory, starting at the current
// working directory and moving upwards, that contains an entry named "proto".
func findProtoParentFolder() (string, error) {
	wd, err := os.Getwd()
	if err == nil {
		return findProtoFolderRecursive(wd)
	}
	return "", err
}
// findProtoFolderRecursive walks upwards from dir and returns the first directory
// that contains an entry named "proto".
//
// It gives up and returns an error when it reaches "/" or "." (neither is checked
// itself), or a filesystem root where filepath.Dir no longer changes the path, such
// as `C:\` on Windows. Without that last check the search never terminated on such
// roots.
func findProtoFolderRecursive(dir string) (string, error) {
	for {
		if dir == "/" || dir == "." {
			return "", fmt.Errorf("proto folder not found")
		}
		if _, err := os.Stat(filepath.Join(dir, "proto")); err == nil {
			return dir, nil
		}
		parent := filepath.Dir(dir)
		if parent == dir {
			return "", fmt.Errorf("proto folder not found")
		}
		dir = parent
	}
}
// parseGoFile parses a generated *grpc.pb.go file and returns one ClientParam per
// gRPC service it declares. Declarations are scanned in order; each service's fields
// are collected until its ServiceDesc variable closes the entry:
//   - ClientName: Xxx from the first XxxClient interface. Later *Client interfaces,
//     such as stream helpers like Xxx_MethodClient, do not overwrite it.
//   - ServerName: Xxx from the UnimplementedXxxServer struct. Other structs ending in
//     "Server", such as stream helpers, are ignored; they used to make the name
//     slicing panic.
//   - ServiceName / GrpcServiceName: from getGrpcFileServiceName on the ServiceDesc
//     var.
// A parse failure is logged and yields no entries.
func parseGoFile(filePath string) (ClientParams []*ClientParam) {
	fset := token.NewFileSet()
	file, err := parser.ParseFile(fset, filePath, nil, parser.ParseComments)
	if err != nil {
		log.Printf("解析文件失败:%v\n", err)
		return
	}
	param := &ClientParam{}
	for _, decl := range file.Decls {
		if serviceName, grpcServiceName, ok := getGrpcFileServiceName(decl); ok {
			param.ServiceName = serviceName
			param.GrpcServiceName = grpcServiceName
			ClientParams = append(ClientParams, param)
			param = &ClientParam{}
		}
		genDecl, ok := decl.(*ast.GenDecl)
		if !ok || genDecl.Tok != token.TYPE {
			continue
		}
		for _, spec := range genDecl.Specs {
			typeSpec, ok := spec.(*ast.TypeSpec)
			if !ok {
				continue
			}
			name := typeSpec.Name.Name
			switch typeSpec.Type.(type) {
			case *ast.StructType:
				if strings.HasPrefix(name, "Unimplemented") && strings.HasSuffix(name, "Server") {
					param.ServerName = strings.TrimSuffix(strings.TrimPrefix(name, "Unimplemented"), "Server")
				}
			case *ast.InterfaceType:
				if param.ClientName == "" && strings.HasSuffix(name, "Client") {
					param.ClientName = strings.TrimSuffix(name, "Client")
				}
			}
		}
	}
	return ClientParams
}
// GrpcServerMethod is one exported method of a generated XxxServer interface. Its
// parameter and result types are rendered as Go source text.
type GrpcServerMethod struct {
	MethodName      string
	Params, Returns []string
}

// GrpcServerInfo describes one XxxServer interface found by ParseGrpcServerInfo.
type GrpcServerInfo struct {
	ServiceName, StructName, UnimplementedStructName string
	Method                                           []*GrpcServerMethod
}
// ParseGrpcServerInfo parses the generated gRPC file at grpcPath. It returns one
// GrpcServerInfo per XxxServer interface; UnsafeXxxServer is skipped.
//
// Each entry lists the interface's methods whose names start with an upper-case
// letter. Parameter and result types are rendered by getTypeString and qualified with
// the file's package name. A method without parameters or without results is
// reported on the log.
//
// The process exits if the file cannot be parsed.
func ParseGrpcServerInfo(grpcPath string) (infos []*GrpcServerInfo) {
	fset := token.NewFileSet()
	node, err := parser.ParseFile(fset, grpcPath, nil, parser.AllErrors)
	if err != nil {
		log.Println("解析文件失败:", err)
		os.Exit(1)
	}
	pkgName := node.Name.Name
	for _, decl := range node.Decls {
		typeDecl, ok := decl.(*ast.GenDecl)
		if !ok || typeDecl.Tok != token.TYPE {
			continue
		}
		for _, spec := range typeDecl.Specs {
			ts, ok := spec.(*ast.TypeSpec)
			if !ok {
				continue
			}
			iface, ok := ts.Type.(*ast.InterfaceType)
			if !ok {
				continue
			}
			ifaceName := ts.Name.Name
			if !strings.HasSuffix(ifaceName, "Server") || strings.HasPrefix(ifaceName, "Unsafe") {
				continue
			}
			svc := strings.TrimSuffix(ifaceName, "Server")
			info := &GrpcServerInfo{
				ServiceName:             svc,
				StructName:              svc + "Logic",
				UnimplementedStructName: "service.Unimplemented" + svc + "Server",
			}
			infos = append(infos, info)

			for _, field := range iface.Methods.List {
				methodName := field.Names[0].Name
				if !isUpper(methodName) {
					continue
				}
				sig := field.Type.(*ast.FuncType)
				m := &GrpcServerMethod{MethodName: methodName}
				info.Method = append(info.Method, m)

				if len(sig.Params.List) == 0 {
					log.Println("无参数")
				}
				for _, p := range sig.Params.List {
					m.Params = append(m.Params, getTypeString(p.Type, pkgName))
				}

				if sig.Results == nil {
					log.Println("无返回值")
					continue
				}
				for _, r := range sig.Results.List {
					m.Returns = append(m.Returns, getTypeString(r.Type, pkgName))
				}
			}
		}
	}
	return infos
}
// isUpper reports whether name starts with an upper-case Unicode letter, which is
// Go's rule for exported identifiers.
//
// The empty string yields false instead of a slice panic. Names starting with "_" or
// a digit also yield false: the old byte comparison treated them as upper case
// because strings.ToUpper leaves them unchanged.
func isUpper(name string) bool {
	return token.IsExported(name)
}
// getTypeString renders the type expression expr as Go source text.
// If a package name is given (only the first variadic value is used), top-level
// upper-case identifiers are qualified with it. Otherwise identifiers are printed as
// they are.
func getTypeString(expr ast.Expr, packageName ...string) string {
	var pkg *string
	if len(packageName) > 0 {
		pkg = &packageName[0]
	}
	return _getTypeString(expr, pkg, 0)
}
// _getTypeString renders the type expression expr back to Go source text.
//
// Qualification: when packageName is non-nil, every exported identifier is qualified
// with it (Foo -> pkg.Foo), including identifiers nested in pointer, slice, map,
// channel, function and generic types. Two kinds of name are never qualified:
//   - the package part of a selector (grpc in grpc.DialOption), which is rendered
//     with level+1;
//   - predeclared lower-case types such as error or int32.
// Only identifiers at level 0 are qualified.
//
// Rendering rules:
//   - Struct and interface types render as "struct{}" / "interface{}" (their fields
//     and methods are dropped).
//   - Parameter and result names are dropped, but `a, b int` still yields one type
//     per name.
//   - Multiple results are parenthesised.
//
// Unsupported expressions cause a panic that names the node type.
func _getTypeString(expr ast.Expr, packageName *string, level int) string {
	// render formats a nested type at the current qualification level.
	render := func(e ast.Expr) string { return _getTypeString(e, packageName, level) }
	// fieldTypes renders a parameter/result list, repeating the type once per name.
	fieldTypes := func(fl *ast.FieldList) []string {
		var out []string
		if fl == nil {
			return out
		}
		for _, f := range fl.List {
			typ := render(f.Type)
			n := len(f.Names)
			if n == 0 {
				n = 1
			}
			for ; n > 0; n-- {
				out = append(out, typ)
			}
		}
		return out
	}

	switch t := expr.(type) {
	case *ast.Ident:
		if level == 0 && packageName != nil && token.IsExported(t.Name) {
			return *packageName + "." + t.Name
		}
		return t.Name
	case *ast.SelectorExpr:
		return _getTypeString(t.X, packageName, level+1) + "." + t.Sel.Name
	case *ast.StarExpr:
		return "*" + render(t.X)
	case *ast.ParenExpr:
		return "(" + render(t.X) + ")"
	case *ast.Ellipsis:
		return "..." + render(t.Elt)
	case *ast.ArrayType:
		length := ""
		switch l := t.Len.(type) {
		case *ast.BasicLit:
			length = l.Value
		case *ast.Ellipsis:
			length = "..."
		}
		return "[" + length + "]" + render(t.Elt)
	case *ast.MapType:
		return "map[" + render(t.Key) + "]" + render(t.Value)
	case *ast.ChanType:
		dir := "chan" // bidirectional: Dir is SEND|RECV
		switch t.Dir {
		case ast.SEND:
			dir = "chan<-"
		case ast.RECV:
			dir = "<-chan"
		}
		return dir + " " + render(t.Value)
	case *ast.IndexExpr:
		return render(t.X) + "[" + render(t.Index) + "]"
	case *ast.IndexListExpr:
		args := make([]string, len(t.Indices))
		for i, idx := range t.Indices {
			args[i] = render(idx)
		}
		return render(t.X) + "[" + strings.Join(args, ", ") + "]"
	case *ast.StructType:
		return "struct{}"
	case *ast.InterfaceType:
		return "interface{}"
	case *ast.FuncType:
		s := "func(" + strings.Join(fieldTypes(t.Params), ", ") + ")"
		results := fieldTypes(t.Results)
		switch len(results) {
		case 0:
		case 1:
			s += " " + results[0]
		default:
			s += " (" + strings.Join(results, ", ") + ")"
		}
		return s
	default:
		panic(fmt.Sprintf("_getTypeString: unsupported type expression %T", expr))
	}
}
// createFileWithPermNotExists creates filename with the content produced by do, but
// only if the file does not exist yet; an existing file is left untouched.
//
// The content is passed through go/format when it parses as Go source, and written
// verbatim otherwise. The file is created with os.Create, which uses mode 0666
// before umask.
//
// The content is rendered before the file is created. A failing do therefore no
// longer leaves an empty file behind, which later runs would have treated as
// already generated.
//
// It returns a non "not exist" os.Stat error, or a file-creation error. It panics if
// do fails or the write fails.
func createFileWithPermNotExists(filename string, do func(f io.Writer) error) error {
	_, err := os.Stat(filename)
	if err == nil {
		return nil // already exists: never overwrite
	}
	if !os.IsNotExist(err) {
		return err
	}

	var buf bytes.Buffer
	if err := do(&buf); err != nil {
		panic(err)
	}
	data := buf.Bytes()
	if formatted, ferr := format.Source(data); ferr == nil {
		data = formatted
	}

	file, err := os.Create(filename)
	if err != nil {
		return err
	}
	defer file.Close()
	if _, err := file.Write(data); err != nil {
		panic(err)
	}
	log.Printf("%s 文件已创建并写入内容\n", filename)
	return nil
}
// createFileNotExists creates filename with permission perm (subject to umask) and
// the content produced by do, written verbatim. It does so only if the file does not
// exist yet; an existing file is left untouched.
//
// The content is rendered before the file is created, so a failing do does not leave
// an empty file behind for later runs to skip.
//
// It returns a file-creation error. It panics if os.Stat fails for a reason other
// than "not exist", if do fails, or if the write fails.
func createFileNotExists(filename string, perm fs.FileMode, do func(f io.Writer) error) error {
	_, err := os.Stat(filename)
	if err == nil {
		return nil // already exists: never overwrite
	}
	if !os.IsNotExist(err) {
		panic(err)
	}

	var buf bytes.Buffer
	if err := do(&buf); err != nil {
		panic(err)
	}

	file, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY, perm)
	if err != nil {
		return err
	}
	defer file.Close()
	if _, err := file.Write(buf.Bytes()); err != nil {
		panic(err)
	}
	log.Printf("%s 文件已创建并写入内容\n", filename)
	return nil
}
// convertToSnakeCase turns a CamelCase identifier into snake_case. Every upper-case
// rune is lower-cased and, unless it is the first rune, preceded by "_". Consecutive
// capitals are split one by one ("HTTP" -> "h_t_t_p").
func convertToSnakeCase(name string) string {
	var sb strings.Builder
	for pos, r := range name {
		if !unicode.IsUpper(r) {
			sb.WriteRune(r)
			continue
		}
		if pos != 0 {
			sb.WriteByte('_')
		}
		sb.WriteRune(unicode.ToLower(r))
	}
	return sb.String()
}
// getCurrentAbPathByCaller returns the directory of this source file as recorded
// by the compiler (useful under `go run`), or "" if the caller information is
// unavailable.
func getCurrentAbPathByCaller() string {
	_, filename, _, ok := runtime.Caller(0)
	if !ok {
		return ""
	}
	return path.Dir(filename)
}
// checkProtoMessageName walks folderPath recursively. It panics if the same message
// name is declared (`message Name {`, with or without a space before the brace) in
// more than one place across the .proto files. It also panics if the walk or a file
// read fails.
//
// NOTE(review): matching is textual, so commented-out messages and same-named nested
// messages in different parents are also reported.
func checkProtoMessageName(folderPath string) {
	messagePattern := regexp.MustCompile(`\bmessage\s+(\w+)\s*\{`)
	firstSeen := make(map[string]string)
	err := filepath.Walk(folderPath, func(p string, info os.FileInfo, err error) error {
		if err != nil {
			return err
		}
		if info.IsDir() || filepath.Ext(p) != ".proto" {
			return nil
		}
		content, err := os.ReadFile(p)
		if err != nil {
			return err
		}
		for _, m := range messagePattern.FindAllStringSubmatch(string(content), -1) {
			name := m[1]
			if prev, dup := firstSeen[name]; dup {
				panic(fmt.Sprintf("重复的消息类型名称 '%s' 在文件 '%s' 和之前的文件 %s 中存在重复。\n", name, p, prev))
			}
			firstSeen[name] = p
		}
		return nil
	})
	if err != nil {
		panic(fmt.Sprintln("遍历文件夹时发生错误:", err))
	}
}
// underscoreToLowerCamelCase splits s on "_", trims spaces from each word, title-cases
// each word with English rules and concatenates the results
// (e.g. "user_info" -> "UserInfo").
// NOTE: despite its name, the first word is capitalised too, so the result is
// UpperCamelCase.
func underscoreToLowerCamelCase(s string) string {
	title := cases.Title(language.English)
	var b strings.Builder
	for _, word := range strings.Split(s, "_") {
		b.WriteString(title.String(strings.Trim(word, " ")))
	}
	return b.String()
}
// getGrpcFileServiceName inspects one top-level declaration of a generated
// *grpc.pb.go file. It looks for a var initialised with a composite literal that has
// a `ServiceName: "pkg.Service"` field, which is the generated grpc.ServiceDesc.
//
// On a match it returns the two dot-separated halves of that string and true. If the
// string does not contain exactly one dot, or no such declaration is found, it
// returns "", "", false.
//
// Vars without values, multi-value vars, unkeyed or empty composite literals, and
// non-string values are skipped rather than causing a panic.
func getGrpcFileServiceName(decl ast.Decl) (string, string, bool) {
	genDecl, ok := decl.(*ast.GenDecl)
	if !ok || genDecl.Tok != token.VAR {
		return "", "", false
	}
	for _, spec := range genDecl.Specs {
		valueSpec, ok := spec.(*ast.ValueSpec)
		if !ok {
			continue
		}
		// Values may be shorter than Names (`var x T`, `var a, b = f()`), so range over Values.
		for _, value := range valueSpec.Values {
			lit, ok := value.(*ast.CompositeLit)
			if !ok {
				continue
			}
			for _, elt := range lit.Elts {
				kv, ok := elt.(*ast.KeyValueExpr)
				if !ok {
					continue
				}
				if key, ok := kv.Key.(*ast.Ident); !ok || key.Name != "ServiceName" {
					continue
				}
				str, ok := kv.Value.(*ast.BasicLit)
				if !ok || str.Kind != token.STRING {
					continue
				}
				parts := strings.Split(strings.Trim(str.Value, "\""), ".")
				if len(parts) != 2 {
					return "", "", false
				}
				return parts[0], parts[1], true
			}
		}
	}
	return "", "", false
}