proto/goutils/proto_build/main.go

1426 lines
37 KiB
Go
Raw Normal View History

2023-11-27 09:36:02 +00:00
package main
import (
"bufio"
"bytes"
"fmt"
"go/ast"
"go/format"
"go/parser"
"go/printer"
"go/token"
"io"
"io/fs"
"log"
"os"
"os/exec"
"path"
"path/filepath"
"regexp"
"runtime"
"strings"
"text/template"
"unicode"
"golang.org/x/text/cases"
"golang.org/x/text/language"
"gopkg.in/ini.v1"
)
// workspaceDir is the directory returned by findProtoParentFolder in init;
// init also makes it the process working directory.
var workspaceDir string
// tpl holds the templates parsed in init from
// proto/goutils/proto_build/tpls/*.tpl under workspaceDir.
var tpl *template.Template
// init prepares the generator before main runs: it switches the logger to
// long file names, locates the workspace root via findProtoParentFolder,
// changes the working directory to it and parses every template under
// proto/goutils/proto_build/tpls. Any failure panics.
func init() {
	log.SetFlags(log.Llongfile)

	root, findErr := findProtoParentFolder()
	workspaceDir = root
	if findErr != nil {
		log.Panic(findErr)
	}

	if chdirErr := os.Chdir(workspaceDir); chdirErr != nil {
		log.Panicf("切换目录失败:%v\n", chdirErr)
	}

	parsed, parseErr := template.ParseGlob(workspaceDir + "/proto/goutils/proto_build/tpls/*.tpl")
	tpl = parsed
	if parseErr != nil {
		panic(parseErr)
	}
}
// main dispatches on the first command-line argument: no argument or
// "service" runs ServiceMain, "gateway" runs GatewayMain, and anything else
// only logs a usage hint.
func main() {
	// Command-line arguments without the executable name.
	args := os.Args[1:]

	switch {
	case len(args) == 0, args[0] == "service":
		ServiceMain()
	case args[0] == "gateway":
		GatewayMain()
	default:
		log.Println("Invalid argument. Usage: go run main.go [service|gateway]")
	}
}
// ServiceMain runs the full generation pipeline for a service project:
// it validates proto message names, reads server/service_config.ini,
// makes sure go.mod carries the fusen replace directives, runs protoc and
// then generates tests, config files, the auto gRPC client and the logic
// skeletons. A protoc.sh helper script is created on return if missing.
func ServiceMain() {
	const (
		packageName    = "service"
		serverCodePath = "server"
		genDir         = "gen/go"
	)

	log.Println("项目目录:", workspaceDir)
	checkProtoMessageName(workspaceDir + "/" + "proto/service")

	ServiceName, projectName, projectLastName := getServiceNameAndProjectName(serverCodePath)
	if ServiceName == "" {
		// Bootstrap an empty config file so the user has something to fill in.
		if mkErr := os.MkdirAll("server", 0755); mkErr != nil {
			log.Println(mkErr)
		}
		createFileWithPermNotExists("server/service_config.ini", func(w io.Writer) error {
			return tpl.ExecuteTemplate(w, "service_config.tpl", nil)
		})
		panic("必须填写server/service_config.ini文件的项目名称与proto对应名称")
	}

	// Deferred so the script is written after generation has finished.
	defer createFileNotExists("protoc.sh", 0755, func(w io.Writer) error {
		const script = `#! /bin/bash
go run -gcflags="-N" proto/goutils/proto_build/main.go`
		_, _ = w.Write([]byte(script))
		return nil
	})

	serviceProtoDir := "proto/" + packageName
	if mkErr := os.MkdirAll(genDir+"/"+packageName, 0755); mkErr != nil {
		log.Println(mkErr)
	}

	CheckGomodBasicPackage()
	ExecProtoc(workspaceDir, serviceProtoDir, genDir, packageName, projectName)
	ExecCreateTest(genDir, projectLastName, packageName, ServiceName)
	ExecCreateConfig(ServiceName, projectLastName)
	ExecCreateAutoGrpc(genDir, packageName)
	ExecCreateAutoLogic(workspaceDir, ServiceName, genDir, packageName, projectLastName)
}
// GatewayMain runs the generation pipeline for a gateway project: it
// validates proto message names, reads server/service_config.ini, runs
// protoc and then generates the auto gRPC client, the config files and the
// gateway registration code. A protoc.sh helper script (invoking this tool
// with the "gateway" argument) is created on return if missing.
func GatewayMain() {
	const (
		packageName    = "service"
		serverCodePath = "server"
		genDir         = "gen/go"
	)

	log.Println("项目目录:", workspaceDir)
	checkProtoMessageName(workspaceDir + "/" + "proto/service")

	ServiceName, projectName, projectLastName := getServiceNameAndProjectName(serverCodePath)
	if ServiceName == "" {
		// Bootstrap an empty config file so the user has something to fill in.
		if mkErr := os.MkdirAll("server", 0755); mkErr != nil {
			log.Println(mkErr)
		}
		createFileWithPermNotExists("server/service_config.ini", func(w io.Writer) error {
			return tpl.ExecuteTemplate(w, "service_config.tpl", nil)
		})
		panic("必须填写server/service_config.ini文件的项目名称与proto对应名称")
	}

	// Deferred so the script is written after generation has finished.
	defer createFileNotExists("protoc.sh", 0755, func(w io.Writer) error {
		const script = `#! /bin/bash
go run -gcflags="-N" proto/goutils/proto_build/main.go gateway`
		_, _ = w.Write([]byte(script))
		return nil
	})

	serviceProtoDir := "proto/" + packageName
	if mkErr := os.MkdirAll(genDir+"/"+packageName, 0755); mkErr != nil {
		log.Println(mkErr)
	}

	ExecProtoc(workspaceDir, serviceProtoDir, genDir, packageName, projectName)
	ExecCreateAutoGrpc(genDir, packageName)
	ExecCreateConfig(ServiceName, projectLastName)
	ExecCreateGatewayAutoGrpc(genDir, packageName, projectLastName)
}
// ExecCreateConfig creates the scaffold files that are only written when
// they do not exist yet: server/config/config.go, server/main_test.go,
// .gitignore and the executable update_fspkg_master.sh. Each file is rendered
// from its template with nil data. The parameters are currently unused.
// It panics when server/config cannot be created.
func ExecCreateConfig(ServiceName, ProjectName string) {
	if mkErr := os.MkdirAll("server/config", 0755); mkErr != nil {
		panic(mkErr)
	}

	// Files created through createFileWithPermNotExists, in creation order.
	scaffold := [...]struct{ file, template string }{
		{"server/config/config.go", "config.tpl"},
		{"server/main_test.go", "main_test.tpl"},
		{".gitignore", "gitignore.tpl"},
	}
	for _, item := range scaffold {
		templateName := item.template
		createFileWithPermNotExists(item.file, func(w io.Writer) error {
			return tpl.ExecuteTemplate(w, templateName, nil)
		})
	}

	createFileNotExists("update_fspkg_master.sh", 0755, func(w io.Writer) error {
		return tpl.ExecuteTemplate(w, "update_fspkg_master.tpl", nil)
	})
}
// CheckGomodBasicPackage makes sure ./go.mod references the fusen base
// modules. For each of fusen-basic and fusen-model, when the module name
// followed by a space does not occur anywhere in go.mod, the matching
// replace directive is appended. The file is rewritten only when something
// was added. It panics when go.mod cannot be read or written.
func CheckGomodBasicPackage() {
	data, err := os.ReadFile("./go.mod")
	if err != nil {
		panic(err)
	}
	original := string(data)

	required := []struct{ marker, directive string }{
		{"fusen-basic ", "\n\nreplace fusen-basic v0.0.0 => gitee.com/fusenpack/fusen-basic v0.0.1"},
		{"fusen-model ", "\n\nreplace fusen-model v0.0.0 => gitee.com/fusenpack/fusen-model v0.0.1"},
	}

	content := original
	for _, mod := range required {
		// A plain substring test is enough: the marker is a fixed literal.
		if !strings.Contains(original, mod.marker) {
			content += mod.directive
		}
	}
	if content == original {
		return
	}

	log.Println("rewrite go.mod => add <fusen-basic> <fusen-model>")
	if err := os.WriteFile("./go.mod", []byte(content), 0644); err != nil {
		panic(err)
	}
}
// ExecProtoc assembles and runs the protoc command line. The base command
// generates go, go-grpc and grpc-gateway output into genDir with
// source-relative paths; one argument group from importFileCmdStr is appended
// for every proto file in proto/service and for every service/*.proto file
// those protos import. The command is run through `sh -c` in workerSpaceDir
// with stdout/stderr attached to this process; a failure is only logged.
func ExecProtoc(workerSpaceDir, serviceProtoDir, genDir, packageName string, projectName string) {
	services := getAllServiceName()

	var command strings.Builder
	command.WriteString(fmt.Sprintf(`protoc -I %s --go_out %s --go_opt paths=source_relative --go-grpc_out %s --go-grpc_opt paths=source_relative --grpc-gateway_out %s --grpc-gateway_opt paths=source_relative`, "proto", genDir, genDir, genDir))
	for _, name := range services {
		command.WriteString(importFileCmdStr(serviceNameEncode(packageName, name), projectName))
	}
	for _, imported := range checkServiceProtoImports(serviceProtoDir, services...) {
		command.WriteString(importFileCmdStr(imported, projectName))
	}

	line := command.String()
	log.Println(line)

	protoc := exec.Command("sh", "-c", line)
	protoc.Dir = workerSpaceDir
	protoc.Stdout = os.Stdout
	protoc.Stderr = os.Stderr

	if runErr := protoc.Run(); runErr != nil {
		log.Printf("命令执行失败:%v\n", runErr)
	}
}
// ExecCreateAutoLogic generates the logic layer for every gRPC server
// interface found in the *grpc.pb.go files of genDir/packageName whose file
// name starts with ServiceName.
//
// For each server interface it
//   - regenerates server/logics/<pkg>/types_gen.go (always overwritten),
//   - creates server/logics/<pkg>/logic_init.go if missing,
//   - creates one <method>_logic.go per RPC method if missing,
//
// and finally creates server/main.go (if missing) from main.tpl with the
// collected services. workerSpaceDir is currently unused.
//
// The generated types file is formatted before it is written, so a
// formatting failure leaves an existing types_gen.go untouched; as before,
// such a failure is logged and the remaining interfaces of that file are
// skipped. Template and file-system errors panic.
func ExecCreateAutoLogic(workerSpaceDir string, ServiceName string, genDir, packageName, projectName string) {
	grpcFilePath := genDir + "/" + packageName

	files, err := os.ReadDir(grpcFilePath)
	if err != nil {
		log.Println("Error reading directory:", err)
		return
	}

	// Field names are referenced by main.tpl and must stay as they are.
	type StructService struct {
		StructServiceName string
		LogicPackageName  string
	}
	type MainTpl struct {
		ProjectName        string
		StructServiceNames []*StructService
		LogicDirNames      []string
	}
	mtpl := &MainTpl{ProjectName: projectName}

	// main.go is rendered last so it sees every collected service.
	defer func() {
		createFileWithPermNotExists("server/main.go", func(f io.Writer) error {
			return tpl.ExecuteTemplate(f, "main.tpl", mtpl)
		})
	}()

	// generate writes all files for one server interface and reports whether
	// processing of the current grpc file should continue.
	generate := func(info *GrpcServerInfo) bool {
		var methods []map[string]string
		for _, met := range info.Method {
			// The templates need a context parameter, a request parameter and
			// a response; skip anything with a different shape instead of
			// panicking on an out-of-range index.
			if len(met.Params) < 2 || len(met.Returns) == 0 {
				log.Printf("skip %s.%s: unexpected method signature\n", info.ServiceName, met.MethodName)
				continue
			}
			methods = append(methods, map[string]string{
				"StructName":     info.StructName,
				"MethodName":     met.MethodName,
				"ParamCtx":       met.Params[0],
				"ParamReq":       met.Params[1],
				"MethodReturn":   met.Returns[0],
				"MethodResponse": strings.TrimPrefix(met.Returns[0], "*"),
			})
		}

		logicPackageName := convertToSnakeCase(info.StructName)

		var genTypesBuffer bytes.Buffer
		if err := tpl.ExecuteTemplate(&genTypesBuffer, "logic_grpc_struct.tpl", map[string]any{
			"ProjectName":             projectName,
			"StructName":              info.StructName,
			"UnimplementedStructName": info.UnimplementedStructName,
			"Methods":                 methods,
			"PackageName":             logicPackageName,
		}); err != nil {
			panic(err)
		}

		logicPath := "server/logics/" + logicPackageName
		mtpl.LogicDirNames = append(mtpl.LogicDirNames, logicPath)
		if err := os.MkdirAll(logicPath, 0755); err != nil {
			panic(err)
		}

		// Format first: only a successfully formatted file replaces the old one.
		formatted, err := format.Source(genTypesBuffer.Bytes())
		if err != nil {
			log.Printf("格式化代码失败:%v\n", err)
			return false
		}
		if err := os.WriteFile(logicPath+"/types_gen.go", formatted, 0644); err != nil {
			panic(err)
		}

		createFileWithPermNotExists(logicPath+"/logic_init.go", func(f io.Writer) error {
			return tpl.ExecuteTemplate(f, "logic_init.tpl", map[string]any{
				"StructNames": []string{info.StructName},
				"ProjectName": projectName,
				"PackageName": logicPackageName,
			})
		})

		for _, method := range methods {
			fileName := convertToSnakeCase(method["MethodName"])
			method["ProjectName"] = projectName
			method["PackageName"] = logicPackageName
			createFileWithPermNotExists(fmt.Sprintf("%s/%s_logic.go", logicPath, fileName), func(f io.Writer) error {
				return tpl.ExecuteTemplate(f, "logic_fusen_handler.tpl", method)
			})
		}

		mtpl.StructServiceNames = append(mtpl.StructServiceNames, &StructService{
			StructServiceName: info.ServiceName,
			LogicPackageName:  logicPackageName,
		})
		return true
	}

	for _, file := range files {
		name := file.Name()
		if file.IsDir() || !strings.HasSuffix(name, "grpc.pb.go") || !strings.HasPrefix(name, ServiceName) {
			continue
		}
		for _, info := range ParseGrpcServerInfo(filepath.Join(grpcFilePath, name)) {
			if !generate(info) {
				break
			}
		}
	}
}
// ExecCreateTest generates the HTTP/gRPC test scaffolding under server/test.
// It collects the handler descriptions from every gateway file in
// genDir/packageName whose name ends with "<lower-cased serviceName>.pb.gw.go",
// always rewrites server/test/test_method_gen.go, and creates (only if
// missing) one <method>_test.go per handler, var.go and the executable
// run_latest.sh. A template error panics; formatting and write errors are
// logged and stop the function.
func ExecCreateTest(genDir, projectName, packageName, serviceName string) {
	const genTestDir = "server/test"

	pkgDir := genDir + "/" + packageName
	gatewaySuffix := strings.ToLower(serviceName) + ".pb.gw.go"

	var struCollection []*HttpGrpcMethodTest
	for _, gwPath := range getSuffixFilesPath(pkgDir, gatewaySuffix) {
		struCollection = append(struCollection, genGatewayTestFunction(gwPath)...)
	}

	var rendered bytes.Buffer
	if execErr := tpl.ExecuteTemplate(&rendered, "http_grpc_method_test.tpl", map[string]any{
		"ProjectName":         projectName,
		"HttpGrpcTestStructs": struCollection,
	}); execErr != nil {
		panic(execErr)
	}

	source, fmtErr := format.Source(rendered.Bytes())
	if fmtErr != nil {
		log.Printf("格式化代码失败:%v\n", fmtErr)
		return
	}

	if mkErr := os.MkdirAll(genTestDir, 0755); mkErr != nil {
		log.Println(mkErr)
	}

	if writeErr := os.WriteFile(genTestDir+"/test_method_gen.go", source, 0644); writeErr != nil {
		log.Printf("无法写入文件:%v\n", writeErr)
		return
	}

	for _, handler := range struCollection {
		data := map[string]any{
			"ProjectName":   projectName,
			"RequestVar":    handler.RequestVar,
			"MethodName":    handler.MethodName,
			"RequestStruct": handler.RequestStruct,
		}
		target := fmt.Sprintf("%s/%s_test.go", genTestDir, convertToSnakeCase(handler.MethodName))
		createFileWithPermNotExists(target, func(w io.Writer) error {
			return tpl.ExecuteTemplate(w, "http_grpc_method_file_test.tpl", data)
		})
	}

	createFileWithPermNotExists(genTestDir+"/"+"var.go", func(w io.Writer) error {
		return tpl.ExecuteTemplate(w, "http_grpc_method_var.tpl", map[string]any{
			"ProjectName": projectName,
		})
	})

	createFileNotExists("run_latest.sh", 0755, func(w io.Writer) error {
		return tpl.ExecuteTemplate(w, "run_latest.tpl", nil)
	})
}
// ClientParam describes one gRPC service found in a generated *grpc.pb.go
// file. parseGoFile fills it from the parsed source, and ExecCreateAutoGrpc
// passes a remapped copy to the auto_grpc_nacos.tpl template.
type ClientParam struct {
// ClientName is a "...Client" interface name with the "Client" suffix removed (parseGoFile).
ClientName string
// ServerName is a "...Server" struct name with the leading "Unimplemented" and trailing "Server" removed (parseGoFile).
ServerName string
// ServiceName is the part before the dot of the service-name literal in parseGoFile;
// ExecCreateAutoGrpc stores the lower-cased ClientName here instead.
ServiceName string
// GrpcServiceName is the part after the dot of the service-name literal in parseGoFile;
// ExecCreateAutoGrpc copies the parsed ServiceName into it.
GrpcServiceName string
}
// ExecCreateAutoGrpc generates
// genDir/packageName/auto_<packageName>_grpc_nacos_client.pb.go from
// auto_grpc_nacos.tpl, using one ClientParam per gRPC service found by
// getGrpcFileClientNames. The file is always overwritten.
//
// A template execution error panics (matching the other generators, and so
// that half-rendered output is never formatted or written); formatting and
// write errors are logged and stop the function.
func ExecCreateAutoGrpc(genDir, packageName string) {
	genDir = genDir + "/" + packageName

	var clientParams []ClientParam
	for _, params := range getGrpcFileClientNames(genDir) {
		clientParams = append(clientParams, ClientParam{
			ClientName:      params.ClientName,
			ServerName:      params.ServerName,
			ServiceName:     cases.Lower(language.English).String(params.ClientName),
			GrpcServiceName: params.ServiceName,
		})
	}

	var buf bytes.Buffer
	if err := tpl.ExecuteTemplate(&buf, "auto_grpc_nacos.tpl", map[string]any{
		"PackageName":  packageName,
		"ClientParams": clientParams,
	}); err != nil {
		panic(err)
	}

	formatted, err := format.Source(buf.Bytes())
	if err != nil {
		log.Printf("格式化代码失败:%v\n", err)
		return
	}

	filePath := fmt.Sprintf("%s/auto_%s_grpc_nacos_client.pb.go", genDir, packageName)
	if err := os.WriteFile(filePath, formatted, 0644); err != nil {
		log.Printf("无法写入文件:%v\n", err)
		return
	}
}
// ExecCreateGatewayAutoGrpc generates the gateway-side code. It renders
// auto_grpc_gateway_nacos.tpl, appends the rewritten Register...Nacos
// functions produced by genGatewayWithNacosFunction for every *pb.gw.go file
// under genDir/packageName, and writes the formatted result to
// auto_<packageName>_grpc_nacos_gateway.pb.go in that directory. It then
// renders auto_grpc_gateway_register.tpl with the collected function names
// into server/logic/gateway_logic_gen.go. Both files are always overwritten.
// Template errors panic; formatting and write errors are logged and stop the
// function.
func ExecCreateGatewayAutoGrpc(genDir, packageName, ProjectName string) {
	pkgDir := genDir + "/" + packageName

	var gateway bytes.Buffer
	if execErr := tpl.ExecuteTemplate(&gateway, "auto_grpc_gateway_nacos.tpl", nil); execErr != nil {
		panic(execErr)
	}

	var funcNames []string
	for _, gwPath := range getSuffixFilesPath(pkgDir, "pb.gw.go") {
		funcNames = append(funcNames, genGatewayWithNacosFunction(gwPath, &gateway)...)
	}

	gatewaySource, fmtErr := format.Source(gateway.Bytes())
	if fmtErr != nil {
		log.Printf("格式化代码失败:%v\n", fmtErr)
		// Dump the unformatted source to help locate the syntax problem.
		log.Println(gateway.String())
		return
	}

	gatewayFile := pkgDir + "/auto_" + packageName + "_grpc_nacos_gateway.pb.go"
	if writeErr := os.WriteFile(gatewayFile, gatewaySource, 0644); writeErr != nil {
		log.Printf("无法写入文件:%v\n", writeErr)
		return
	}

	if mkErr := os.MkdirAll("server/logic", 0755); mkErr != nil {
		log.Println(mkErr)
	}

	// Registration code for the gateway logic package.
	var register bytes.Buffer
	if execErr := tpl.ExecuteTemplate(&register, "auto_grpc_gateway_register.tpl", map[string]any{
		"ProjectName": ProjectName,
		"FuncNames":   funcNames,
	}); execErr != nil {
		panic(execErr)
	}

	registerSource, fmtErr := format.Source(register.Bytes())
	if fmtErr != nil {
		log.Printf("格式化代码失败:%v\n", fmtErr)
		return
	}

	if writeErr := os.WriteFile("server/logic/gateway_logic_gen.go", registerSource, 0644); writeErr != nil {
		log.Printf("无法写入文件:%v\n", writeErr)
		return
	}
}
// genGatewayWithNacosFunction parses the Go file at grpcPath and, for every
// exported function whose name ends in "HandlerClient", appends a rewritten
// copy named "<original>Nacos" to gatewayBuf. The names of the rewritten
// functions are returned.
//
// The rewrite replaces the last parameter with `opts ...grpc.DialOption` and,
// inside each handler func literal, inserts
// `grpcClient, err := Auto<Name>ClientEx(ctx, opts...)` before the
// `request_*` call and swaps that call's `client` argument for `grpcClient`.
//
// A parse failure is logged and terminates the process with exit code 1.
func genGatewayWithNacosFunction(grpcPath string, gatewayBuf *bytes.Buffer) (funcNames []string) {
// Example path from the original author: workerSpaceDir+"/gen/go/service/auth_grpc.pb.go"
// Parse the Go source file.
fset := token.NewFileSet()
node, err := parser.ParseFile(fset, grpcPath, nil, parser.AllErrors)
if err != nil {
log.Println("解析文件失败:", err)
os.Exit(1)
}
// h and e are the lengths of the prefix and suffix stripped from the
// function name to obtain the service name (cName) below.
h := len("Register")
e := len("HandlerClient")
for _, decl := range node.Decls {
// Only exported function declarations named "...HandlerClient" are rewritten.
if fdec, ok := decl.(*ast.FuncDecl); ok && fdec.Name.IsExported() && strings.HasSuffix(fdec.Name.Name, "HandlerClient") {
// cName is the text between the first h and the last e bytes of the name.
cName := fdec.Name.Name[h : len(fdec.Name.Name)-e]
fdec.Name.Name = fdec.Name.Name + "Nacos"
funcNames = append(funcNames, fdec.Name.Name)
// Replace the last parameter with `opts ...grpc.DialOption`.
field := fdec.Type.Params.List[len(fdec.Type.Params.List)-1]
field.Names = []*ast.Ident{{Name: "opts"}}
field.Type = &ast.Ellipsis{
Elt: &ast.SelectorExpr{
X: &ast.Ident{Name: "grpc"},
Sel: &ast.Ident{Name: "DialOption"},
},
}
// Look for top-level call statements with exactly three arguments whose
// third argument is a func literal (the per-route handler).
for _, b := range fdec.Body.List {
if bexpr, ok := b.(*ast.ExprStmt); ok {
if pfunc, ok := bexpr.X.(*ast.CallExpr); ok && len(pfunc.Args) == 3 {
if x, ok := pfunc.Args[2].(*ast.FuncLit); ok {
for i, bb := range x.Body.List {
if as, ok := bb.(*ast.AssignStmt); ok {
if mycb, ok := as.Rhs[0].(*ast.CallExpr); ok {
// The first assignment whose right-hand side calls a request_* function.
if strings.HasPrefix(getTypeString(mycb.Fun), "request_") {
for i, arg := range mycb.Args {
// Find the argument that is the identifier "client".
if ident, ok := arg.(*ast.Ident); ok && ident.Name == "client" {
// Replace it with the locally created grpcClient.
mycb.Args[i] = ast.NewIdent("grpcClient")
break
}
}
opts := ast.NewIdent("opts")
// NOTE(review): Ellipsis is set to opts.End(), a non-zero position;
// presumably this is what makes the printer emit `opts...` — confirm.
autoClientFunc := &ast.CallExpr{
Fun: ast.NewIdent("Auto" + cName + "ClientEx"),
Args: []ast.Expr{
ast.NewIdent("ctx"),
opts,
},
Ellipsis: opts.End(),
}
// Build `grpcClient, err := Auto<cName>ClientEx(ctx, opts...)`.
var dstmt ast.Stmt = &ast.AssignStmt{
Tok: token.DEFINE,
Lhs: []ast.Expr{
ast.NewIdent("grpcClient"),
ast.NewIdent("err"),
},
Rhs: []ast.Expr{
autoClientFunc,
},
}
// Splice the new statement, followed by a reuse of the statement at
// index 6 of the handler body (named ifErr), in front of statement i.
// NOTE(review): index 6 is hard-coded and assumes the exact layout of
// the generated handler; it panics on a shorter body — confirm against
// the protoc-gen-grpc-gateway version in use.
var blist []ast.Stmt
ifErr := x.Body.List[6]
blist = append(blist, x.Body.List[:i]...)
blist = append(blist, dstmt)
blist = append(blist, ifErr)
x.Body.List = append(blist, x.Body.List[i:]...)
break
}
}
}
}
}
}
}
}
// An earlier ast.Inspect-based variant that replaced the `client` argument
// with an inline Auto<cName>ClientEx(ctx, opts...) call was disabled here.
// Append the rewritten function to the output buffer.
err = gatewayBuf.WriteByte('\n')
if err != nil {
panic(err)
}
printer.Fprint(gatewayBuf, fset, fdec)
}
}
return
}
// HttpGrpcMethodTest describes one gateway handler extracted by
// genGatewayTestFunction; ExecCreateTest feeds it to the test templates.
type HttpGrpcMethodTest struct {
// RequestVar is "var" + MethodName + "Req".
RequestVar string
// RequestStruct is the type string of the `protoReq` variable in the local_request_* function.
RequestStruct string
// MethodName is the selector called on the `server` identifier.
MethodName string
// HttpMethod is the literal text of the first argument of the enclosing Handle call.
HttpMethod string
// UrlPath is the literal text of the first argument of WithHTTPPathPattern.
UrlPath string
// ServiceName is the third "_"-separated segment of the local_request_* function name.
ServiceName string
}
// genGatewayTestFunction parses the Go file at grpcPath and returns one
// HttpGrpcMethodTest for every `server.<Method>` call found inside the
// local_request_* functions that are invoked from exported functions whose
// name ends in "HandlerServer".
//
// A parse failure is logged and terminates the process with exit code 1.
func genGatewayTestFunction(grpcPath string) (createdCollection []*HttpGrpcMethodTest) {
// Example path from the original author: workerSpaceDir+"/gen/go/service/auth_grpc.pb.go"
// Parse the Go source file.
fset := token.NewFileSet()
node, err := parser.ParseFile(fset, grpcPath, nil, parser.AllErrors)
if err != nil {
log.Println("解析文件失败:", err)
os.Exit(1)
}
for _, decl := range node.Decls {
// Only exported function declarations named "...HandlerServer" are inspected.
if fdec, ok := decl.(*ast.FuncDecl); ok && fdec.Name.IsExported() && strings.HasSuffix(fdec.Name.Name, "HandlerServer") {
// created accumulates the fields of the entry currently being built; it
// is replaced by a fresh value each time an entry is appended.
created := &HttpGrpcMethodTest{}
// Walk the function body.
ast.Inspect(fdec.Body, func(n ast.Node) bool {
if callExpr, ok := n.(*ast.CallExpr); ok {
if expr, ok := callExpr.Fun.(*ast.Ident); ok && strings.HasPrefix(expr.Name, "local_request") {
// NOTE(review): assumes the name has at least three "_"-separated
// segments and that expr.Obj resolves to a FuncDecl in this file;
// either assumption failing panics here.
created.ServiceName = strings.Split(expr.Name, "_")[2]
// Descend into the body of the referenced local_request_* function.
ast.Inspect(expr.Obj.Decl.(*ast.FuncDecl).Body, func(n ast.Node) bool {
if protoValue, ok := n.(*ast.ValueSpec); ok && protoValue.Names[0].Name == "protoReq" {
created.RequestStruct = getTypeString(protoValue.Type)
}
if expr, ok := n.(*ast.SelectorExpr); ok {
if exprServer, ok := expr.X.(*ast.Ident); ok && exprServer.Name == "server" {
// `server.<Method>` completes the entry: record it and start a new one.
created.MethodName = expr.Sel.Name
createdCollection = append(createdCollection, created)
created.RequestVar = "var" + created.MethodName + "Req"
created = &HttpGrpcMethodTest{}
return false
}
}
return true
})
return true
}
// A `<x>.Handle(...)` call supplies the HTTP method (first argument) and,
// through a nested WithHTTPPathPattern call, the URL path.
// NOTE(review): this appears to rely on ast.Inspect visiting the Handle
// call before the local_request_* call nested in its handler — confirm.
if expr, ok := callExpr.Fun.(*ast.SelectorExpr); ok && expr.Sel.Name == "Handle" {
var x = callExpr.Args[0].(*ast.BasicLit)
created.HttpMethod = x.Value
ast.Inspect(callExpr, func(n ast.Node) bool {
if callExprHandler, ok := n.(*ast.CallExpr); ok {
if expr, ok := callExprHandler.Fun.(*ast.SelectorExpr); ok && expr.Sel.Name == "WithHTTPPathPattern" {
var x = callExprHandler.Args[0].(*ast.BasicLit)
created.UrlPath = x.Value
}
}
return true
})
return true
}
}
return true
})
}
}
return
}
// getSuffixFilesPath walks dir recursively and returns the paths of all
// regular entries (anything that is not a directory) whose file name ends
// with suffix, in walk order. It panics when the walk reports an error.
func getSuffixFilesPath(dir string, suffix string) (result []string) {
	collect := func(path string, info os.FileInfo, walkErr error) error {
		switch {
		case walkErr != nil:
			return walkErr
		case !info.IsDir() && strings.HasSuffix(info.Name(), suffix):
			result = append(result, path)
		}
		return nil
	}

	if walkErr := filepath.Walk(dir, collect); walkErr != nil {
		panic(walkErr)
	}
	return result
}
// getGrpcFileClientNames walks genDir recursively, runs parseGoFile on every
// file whose name ends with "grpc.pb.go" and returns the concatenated
// results. When the walk fails the error is logged and nil is returned.
func getGrpcFileClientNames(genDir string) (result []*ClientParam) {
	walkErr := filepath.Walk(genDir, func(path string, info os.FileInfo, err error) error {
		switch {
		case err != nil:
			return err
		case info.IsDir():
			return nil
		case strings.HasSuffix(info.Name(), "grpc.pb.go"):
			result = append(result, parseGoFile(path)...)
		}
		return nil
	})
	if walkErr == nil {
		return result
	}

	log.Printf("遍历目录失败:%v\n", walkErr)
	return nil
}
// importFileCmdStr returns the protoc argument group for one proto file: the
// file itself followed by the M<file>=<project> mapping for the go, go-grpc
// and grpc-gateway plugins. The result starts and ends with a single space so
// it can be appended directly to a command line.
//
// The text is produced with a single format string; the previous
// implementation parsed and executed a text/template on every call to obtain
// the same output.
func importFileCmdStr(importFile string, projectName string) string {
	return fmt.Sprintf(
		" %[1]s --go_opt=M%[1]s=%[2]s --go-grpc_opt=M%[1]s=%[2]s --grpc-gateway_opt=M%[1]s=%[2]s ",
		importFile, projectName,
	)
}
// serviceNameEncode returns the slash-separated path of a service's proto
// file: "<serviceProtoDir>/<serviceName>.proto".
func serviceNameEncode(serviceProtoDir string, serviceName string) string {
	protoFile := fmt.Sprintf("%s.proto", serviceName)
	return fmt.Sprintf("%s/%s", serviceProtoDir, protoFile)
}
// serviceImportPattern matches a proto import of a file under service/ and
// captures the imported path.
var serviceImportPattern = regexp.MustCompile(`import\s+"(service/.*\.proto)"`)

// checkServiceProtoImports scans <serviceProtoDir>/<name>.proto for every
// given service name and returns the distinct `service/...proto` paths they
// import. Lines that start with "//" after trimming are ignored.
//
// Paths are returned in order of first appearance, so the protoc command
// built from them is the same on every run. The result is nil when nothing
// is imported. It panics when a proto file cannot be opened; a read error is
// logged and the lines read so far are kept.
func checkServiceProtoImports(serviceProtoDir string, serviceNames ...string) []string {
	seen := make(map[string]bool)
	var result []string
	for _, sname := range serviceNames {
		fname := fmt.Sprintf("%s/%s.proto", serviceProtoDir, sname)
		for _, imported := range scanServiceProtoImports(fname) {
			if seen[imported] {
				continue
			}
			seen[imported] = true
			result = append(result, imported)
		}
	}
	return result
}

// scanServiceProtoImports returns every service/*.proto import found in one
// proto file. Keeping this in its own function closes each file as soon as
// it has been scanned instead of holding all of them open.
func scanServiceProtoImports(fname string) []string {
	file, err := os.Open(fname)
	if err != nil {
		panic(fmt.Sprintf("无法打开文件: %s\n", err))
	}
	defer file.Close()

	var found []string
	scanner := bufio.NewScanner(file)
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if strings.HasPrefix(line, "//") {
			continue
		}
		for _, match := range serviceImportPattern.FindAllStringSubmatch(line, -1) {
			found = append(found, match[1])
		}
	}
	if err := scanner.Err(); err != nil {
		log.Printf("读取文件失败: %s\n", err)
	}
	return found
}
// getAllServiceName lists the proto/service directory (relative to the
// working directory) and returns the base name, without the ".proto"
// extension, of every proto file in it.
//
// Sub-directories and files with any other extension are skipped; previously
// the last six bytes of every entry name were cut off unconditionally, which
// panicked on short names and produced bogus service names for other files.
// It panics when the directory cannot be read.
func getAllServiceName() (result []string) {
	const protoExt = ".proto"

	entries, err := os.ReadDir("proto/service")
	if err != nil {
		panic(err)
	}
	for _, entry := range entries {
		name := entry.Name()
		if entry.IsDir() || !strings.HasSuffix(name, protoExt) {
			continue
		}
		if base := strings.TrimSuffix(name, protoExt); base != "" {
			result = append(result, base)
		}
	}
	return result
}
// getServiceNameAndProjectName reads <dir>/service_config.ini and returns
// SERVICE_NAME, PROJECT_NAME and the last "/"-separated segment of
// PROJECT_NAME from its DEFAULT section.
//
// When the file cannot be loaded, dir is created and a template config file
// is written there (if missing) before panicking, so the user has something
// to fill in. The directory that is created is now dir itself rather than a
// hard-coded "server". It also panics when a key is missing or empty.
func getServiceNameAndProjectName(dir string) (serviceName string, projectName string, projectLastName string) {
	ifile, err := ini.Load(dir + "/service_config.ini")
	if err != nil {
		if mkErr := os.MkdirAll(dir, 0755); mkErr != nil {
			log.Println(mkErr)
		}
		createFileWithPermNotExists(dir+"/service_config.ini", func(f io.Writer) error {
			return tpl.ExecuteTemplate(f, "service_config.tpl", nil)
		})
		panic("必须填写server/service_config.ini文件的项目名称与proto对应名称")
	}

	section := ifile.Section("DEFAULT")

	key, err := section.GetKey("SERVICE_NAME")
	if err != nil {
		panic(err)
	}
	serviceName = key.String()
	if serviceName == "" {
		panic("必须填写server/service_config.ini文件的项目名称与proto对应名称")
	}

	key, err = section.GetKey("PROJECT_NAME")
	if err != nil {
		panic(err)
	}
	projectName = key.String()
	if projectName == "" {
		panic("必须填写server/service_config.ini文件的ProjectName称与仓库对应名称")
	}

	segments := strings.Split(projectName, "/")
	projectLastName = segments[len(segments)-1]
	return serviceName, projectName, projectLastName
}
// findProtoParentFolder starts at the current working directory and returns
// the nearest directory, searching upwards, that contains a "proto" entry.
// It returns the error from os.Getwd or from the upward search.
func findProtoParentFolder() (string, error) {
	start, wdErr := os.Getwd()
	if wdErr == nil {
		return findProtoFolderRecursive(start)
	}
	return "", wdErr
}
// findProtoFolderRecursive returns the first directory, starting at dir and
// moving up through its parents, that contains an entry named "proto".
//
// The search fails with an error once it reaches "/" or "." (those two are
// not checked themselves), or once a directory is its own parent. The latter
// guard stops the search at roots other than "/", such as a Windows volume
// root, where the previous recursive version never terminated.
func findProtoFolderRecursive(dir string) (string, error) {
	for {
		if dir == "/" || dir == "." {
			return "", fmt.Errorf("proto folder not found")
		}
		if _, err := os.Stat(filepath.Join(dir, "proto")); err == nil {
			return dir, nil
		}
		parent := filepath.Dir(dir)
		if parent == dir {
			// Reached a root that has no parent.
			return "", fmt.Errorf("proto folder not found")
		}
		dir = parent
	}
}
// parseGoFile parses a generated *grpc.pb.go file and returns one ClientParam
// per service descriptor variable recognised by getGrpcFileServiceName.
//
// Declarations are processed in file order. Type declarations update the
// parameter currently being built:
//   - an interface named "<X>Client" sets ClientName to <X>;
//   - a struct named "Unimplemented<X>Server" sets ServerName to <X>.
//
// When a service descriptor variable is reached, its two name parts are
// stored, the parameter is appended to the result and a new one is started.
//
// Structs that merely end in "Server" without the "Unimplemented" prefix are
// now ignored; they used to be sliced as if they had the prefix, which
// panicked for short names and produced a bogus ServerName otherwise.
// A parse error is logged and yields a nil result.
func parseGoFile(filePath string) (ClientParams []*ClientParam) {
	const (
		serverPrefix = "Unimplemented"
		serverSuffix = "Server"
		clientSuffix = "Client"
	)

	fset := token.NewFileSet()
	file, err := parser.ParseFile(fset, filePath, nil, parser.ParseComments)
	if err != nil {
		log.Printf("解析文件失败:%v\n", err)
		return
	}

	Param := &ClientParam{}
	for _, decl := range file.Decls {
		if ServiceName, GrpcServiceName, ok := getGrpcFileServiceName(decl); ok {
			Param.ServiceName = ServiceName
			Param.GrpcServiceName = GrpcServiceName
			ClientParams = append(ClientParams, Param)
			Param = &ClientParam{}
		}

		genDecl, ok := decl.(*ast.GenDecl)
		if !ok || genDecl.Tok != token.TYPE {
			continue
		}
		for _, spec := range genDecl.Specs {
			typeSpec, ok := spec.(*ast.TypeSpec)
			if !ok {
				continue
			}
			name := typeSpec.Name.Name
			switch typeSpec.Type.(type) {
			case *ast.StructType:
				if strings.HasPrefix(name, serverPrefix) && strings.HasSuffix(name, serverSuffix) {
					Param.ServerName = strings.TrimSuffix(strings.TrimPrefix(name, serverPrefix), serverSuffix)
				}
			case *ast.InterfaceType:
				if strings.HasSuffix(name, clientSuffix) {
					Param.ClientName = strings.TrimSuffix(name, clientSuffix)
				}
			}
		}
	}
	return ClientParams
}
// GrpcServerMethod describes one exported method of a gRPC server interface,
// as collected by ParseGrpcServerInfo.
type GrpcServerMethod struct {
// MethodName is the method's identifier.
MethodName string
// Params holds one type string per parameter field, produced by getTypeString with the file's package name.
Params []string
// Returns holds one type string per result field, produced the same way.
Returns []string
}
// GrpcServerInfo describes one gRPC server interface found by
// ParseGrpcServerInfo.
type GrpcServerInfo struct {
// ServiceName is the interface name with its trailing "Server" removed.
ServiceName string
// StructName is ServiceName + "Logic".
StructName string
// UnimplementedStructName is "service.Unimplemented" + ServiceName + "Server".
UnimplementedStructName string
// Method lists the interface's exported methods in declaration order.
Method []*GrpcServerMethod
}
// ParseGrpcServerInfo parses a generated *grpc.pb.go file and returns one
// GrpcServerInfo for every interface whose name ends in "Server" and does not
// start with "Unsafe". For each such interface the exported methods are
// recorded together with the type strings of their parameters and results
// (qualified with the file's package name by getTypeString).
//
// Interface elements that are not named methods (embedded interfaces) are
// skipped; they used to cause an index-out-of-range panic.
// A parse failure is logged and terminates the process with exit code 1.
func ParseGrpcServerInfo(grpcPath string) (infos []*GrpcServerInfo) {
	fset := token.NewFileSet()
	node, err := parser.ParseFile(fset, grpcPath, nil, parser.AllErrors)
	if err != nil {
		log.Println("解析文件失败:", err)
		os.Exit(1)
	}
	packageName := node.Name.Name

	for _, decl := range node.Decls {
		genDecl, ok := decl.(*ast.GenDecl)
		if !ok || genDecl.Tok != token.TYPE {
			continue
		}
		for _, spec := range genDecl.Specs {
			typeSpec, ok := spec.(*ast.TypeSpec)
			if !ok {
				continue
			}
			ifaceType, ok := typeSpec.Type.(*ast.InterfaceType)
			if !ok {
				continue
			}
			ifaceName := typeSpec.Name.Name
			if !strings.HasSuffix(ifaceName, "Server") || strings.HasPrefix(ifaceName, "Unsafe") {
				continue
			}

			ServiceName := strings.TrimSuffix(ifaceName, "Server")
			info := &GrpcServerInfo{
				ServiceName:             ServiceName,
				StructName:              ServiceName + "Logic",
				UnimplementedStructName: "service.Unimplemented" + ServiceName + "Server",
			}
			infos = append(infos, info)

			for _, method := range ifaceType.Methods.List {
				funcType, isMethod := method.Type.(*ast.FuncType)
				if !isMethod || len(method.Names) == 0 {
					// Embedded interface or type element: nothing to generate.
					continue
				}
				if !isUpper(method.Names[0].Name) {
					continue
				}

				m := &GrpcServerMethod{MethodName: method.Names[0].Name}
				info.Method = append(info.Method, m)

				if funcType.Params != nil && len(funcType.Params.List) > 0 {
					for _, field := range funcType.Params.List {
						m.Params = append(m.Params, getTypeString(field.Type, packageName))
					}
				} else {
					log.Println("无参数")
				}

				if funcType.Results != nil {
					for _, field := range funcType.Results.List {
						m.Returns = append(m.Returns, getTypeString(field.Type, packageName))
					}
				} else {
					log.Println("无返回值")
				}
			}
		}
	}
	return infos
}
// isUpper reports whether name starts with an upper-case letter, i.e.
// whether it would be an exported Go identifier. The empty string yields
// false (it used to panic), and the check is made on the first rune rather
// than the first byte, so a leading "_" is no longer treated as upper case.
func isUpper(name string) bool {
	// IndexFunc returns 0 exactly when the first rune is upper case.
	return strings.IndexFunc(name, unicode.IsUpper) == 0
}
// getTypeString renders the type expression expr as source text. When a
// package name is supplied, only the first one is used; it is handed to
// _getTypeString as the qualifier, otherwise no qualifier is passed.
func getTypeString(expr ast.Expr, packageName ...string) string {
	var qualifier *string
	if len(packageName) > 0 {
		qualifier = &packageName[0]
	}
	return _getTypeString(expr, qualifier, 0)
}
// _getTypeString renders the type expression expr as Go source text.
//
// packageName, when non-nil, is prefixed to upper-case identifiers found at
// level 0, so that types declared in the parsed package become
// "<package>.<Type>". level is 0 for the outermost type and is increased for
// every nested component except the target of a pointer, which keeps the
// level of the pointer itself.
//
// Struct and interface types are abbreviated to "struct{}" and
// "interface{}". Bidirectional channels are rendered as "chan T", variadic
// parameters as "...T", and multiple function results are parenthesised.
// It panics on an expression kind it does not know.
func _getTypeString(expr ast.Expr, packageName *string, level int) string {
	switch t := expr.(type) {
	case *ast.Ident:
		if level == 0 && packageName != nil && isUpper(t.Name) {
			return *packageName + "." + t.Name
		}
		return t.Name
	case *ast.SelectorExpr:
		return fmt.Sprintf("%s.%s", _getTypeString(t.X, packageName, level+1), t.Sel.Name)
	case *ast.StarExpr:
		return "*" + _getTypeString(t.X, packageName, level)
	case *ast.ArrayType:
		return "[]" + _getTypeString(t.Elt, packageName, level+1)
	case *ast.Ellipsis:
		return "..." + _getTypeString(t.Elt, packageName, level+1)
	case *ast.MapType:
		return fmt.Sprintf("map[%s]%s", _getTypeString(t.Key, packageName, level+1), _getTypeString(t.Value, packageName, level+1))
	case *ast.ChanType:
		// Dir is SEND|RECV for a bidirectional channel, which needs the
		// plain "chan" keyword.
		keyword := "chan"
		switch t.Dir {
		case ast.SEND:
			keyword = "chan<-"
		case ast.RECV:
			keyword = "<-chan"
		}
		return fmt.Sprintf("%s %s", keyword, _getTypeString(t.Value, packageName, level+1))
	case *ast.StructType:
		return "struct{}"
	case *ast.InterfaceType:
		return "interface{}"
	case *ast.FuncType:
		var params []string
		if t.Params != nil {
			for _, param := range t.Params.List {
				params = append(params, _getTypeString(param.Type, packageName, level+1))
			}
		}
		var results []string
		if t.Results != nil {
			for _, result := range t.Results.List {
				results = append(results, _getTypeString(result.Type, packageName, level+1))
			}
		}
		signature := fmt.Sprintf("func(%s)", strings.Join(params, ", "))
		switch len(results) {
		case 0:
			return signature
		case 1:
			return signature + " " + results[0]
		default:
			return signature + " (" + strings.Join(results, ", ") + ")"
		}
	default:
		panic(fmt.Sprintf("unsupported type expression %T", t))
	}
}
// createFileWithPermNotExists creates filename with the content produced by
// do, but only when the file does not exist yet; an existing file is left
// untouched and do is not called.
//
// The content is passed through format.Source; when that fails (for example
// because the file is not Go code) the unformatted content is written.
//
// The content is rendered before the file is created, so a failing do no
// longer leaves an empty file behind that would suppress generation on every
// later run.
//
// Errors from os.Stat (other than "not exist") and from creating the file
// are returned. An error from do or from writing panics, because callers do
// not check the returned error.
func createFileWithPermNotExists(filename string, do func(f io.Writer) error) error {
	if _, err := os.Stat(filename); err == nil {
		// The file already exists.
		return nil
	} else if !os.IsNotExist(err) {
		return err
	}

	var buf bytes.Buffer
	if err := do(&buf); err != nil {
		panic(err)
	}
	data := buf.Bytes()
	if formatted, err := format.Source(data); err == nil {
		data = formatted
	}

	file, err := os.Create(filename)
	if err != nil {
		return err
	}
	defer file.Close()

	if _, err := file.Write(data); err != nil {
		panic(err)
	}
	log.Printf("%s 文件已创建并写入内容\n", filename)
	return nil
}
// createFileNotExists creates filename with permission bits perm and the
// content produced by do, but only when the file does not exist yet; an
// existing file is left untouched and do is not called. The content is
// written exactly as produced (no formatting).
//
// The content is rendered before the file is created, so a failing do no
// longer leaves an empty file behind that would suppress generation on every
// later run.
//
// An error from opening the new file is returned. An os.Stat error other
// than "not exist", an error from do, and a write error all panic.
func createFileNotExists(filename string, perm fs.FileMode, do func(f io.Writer) error) error {
	if _, err := os.Stat(filename); err == nil {
		// The file already exists.
		return nil
	} else if !os.IsNotExist(err) {
		panic(err)
	}

	var buf bytes.Buffer
	if err := do(&buf); err != nil {
		panic(err)
	}

	file, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY, perm)
	if err != nil {
		return err
	}
	defer file.Close()

	if _, err := file.Write(buf.Bytes()); err != nil {
		panic(err)
	}
	log.Printf("%s 文件已创建并写入内容\n", filename)
	return nil
}
// convertToSnakeCase lower-cases every upper-case rune of name and puts an
// underscore in front of it, except when it is the very first rune. Each
// upper-case rune is handled on its own, so "HTTPServer" becomes
// "h_t_t_p_server".
func convertToSnakeCase(name string) string {
	var snake strings.Builder
	for offset, r := range name {
		if !unicode.IsUpper(r) {
			snake.WriteRune(r)
			continue
		}
		if offset != 0 {
			snake.WriteRune('_')
		}
		snake.WriteRune(unicode.ToLower(r))
	}
	return snake.String()
}
// getCurrentAbPathByCaller returns the directory part of the source file
// name that runtime.Caller(0) reports for this function, or "" when the
// caller information is unavailable.
func getCurrentAbPathByCaller() string {
	if _, sourceFile, _, found := runtime.Caller(0); found {
		return path.Dir(sourceFile)
	}
	return ""
}
// checkProtoMessageName walks folderPath recursively and makes sure that no
// message name is declared twice across (or within) the .proto files found.
// Declarations are recognised textually by the pattern `message <Name> {`.
// It panics, naming both files, on the first duplicate, and also when the
// walk or reading a file fails.
func checkProtoMessageName(folderPath string) {
	// Message name -> path of the file that declared it first.
	firstSeenIn := make(map[string]string)
	messagePattern := regexp.MustCompile(`message\s+(\w+)\s+{`)

	inspect := func(path string, info os.FileInfo, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}
		if filepath.Ext(path) != ".proto" {
			return nil
		}
		content, readErr := os.ReadFile(path)
		if readErr != nil {
			return readErr
		}
		for _, match := range messagePattern.FindAllStringSubmatch(string(content), -1) {
			messageName := match[1]
			if previous, duplicate := firstSeenIn[messageName]; duplicate {
				panic(fmt.Sprintf("重复的消息类型名称 '%s' 在文件 '%s' 和之前的文件 %s 中存在重复。\n", messageName, path, previous))
			}
			firstSeenIn[messageName] = path
		}
		return nil
	}

	if walkErr := filepath.Walk(folderPath, inspect); walkErr != nil {
		panic(fmt.Sprintln("遍历文件夹时发生错误:", walkErr))
	}
}
// underscoreToLowerCamelCase splits s on "_", trims spaces from each piece,
// passes each piece through cases.Title(language.English) and concatenates
// the results without a separator.
// NOTE(review): every piece, including the first, goes through the title
// caser, so the result presumably starts with an upper-case letter despite
// the function name — confirm before relying on it.
func underscoreToLowerCamelCase(s string) string {
	titleCaser := cases.Title(language.English)

	var joined strings.Builder
	for _, piece := range strings.Split(s, "_") {
		joined.WriteString(titleCaser.String(strings.Trim(piece, " ")))
	}
	return joined.String()
}
// getGrpcFileServiceName inspects a top-level declaration and, when it is a
// `var` whose value is a composite literal, reads the first element of that
// literal. If the element is a key/value pair with a basic literal value of
// the form "a.b", it returns ("a", "b", true).
//
// It returns ("", "", false) for every other declaration, and also as soon
// as such a literal value does not consist of exactly two dot-separated
// parts. Variables without an initialiser, empty composite literals and
// positional (non key/value) first elements are skipped; they used to panic.
func getGrpcFileServiceName(decl ast.Decl) (string, string, bool) {
	genDecl, ok := decl.(*ast.GenDecl)
	if !ok || genDecl.Tok != token.VAR {
		return "", "", false
	}

	for _, spec := range genDecl.Specs {
		valueSpec, ok := spec.(*ast.ValueSpec)
		if !ok {
			continue
		}
		for _, value := range valueSpec.Values {
			literal, ok := value.(*ast.CompositeLit)
			if !ok || len(literal.Elts) == 0 {
				continue
			}
			first, ok := literal.Elts[0].(*ast.KeyValueExpr)
			if !ok {
				continue
			}
			nameLit, ok := first.Value.(*ast.BasicLit)
			if !ok {
				continue
			}
			parts := strings.Split(strings.Trim(nameLit.Value, "\""), ".")
			if len(parts) != 2 {
				return "", "", false
			}
			return parts[0], parts[1], true
		}
	}
	return "", "", false
}