package main import ( "database/sql" "flag" "fmt" "log" "os" "os/exec" "regexp" "sort" "strconv" "strings" "golang.org/x/text/cases" "golang.org/x/text/language" _ "github.com/go-sql-driver/mysql" ) var testName = "fs_auth_item" var testGenDir = "model/gmodel" func toPascalCase(s string) string { words := strings.Split(s, "_") for i, word := range words { words[i] = cases.Title(language.English).String(strings.ToLower(word)) } return strings.Join(words, "") } func GetAllTableNames(uri string) []string { db, err := sql.Open("mysql", uri) if err != nil { panic(err) } defer db.Close() rows, err := db.Query("SHOW TABLES") if err != nil { panic(err) } var tableNames []string for rows.Next() { var tableName string if err := rows.Scan(&tableName); err != nil { panic(err) } tableNames = append(tableNames, tableName) } return tableNames } // "fusentest:XErSYmLELKMnf3Dh@tcp(110.41.19.98:3306)/fusentest" func GetColsFromTable(tname string, db *sql.DB) (result []Column, tableName, tableComment string) { var a, ddl string err := db.QueryRow("SHOW CREATE TABLE "+tname).Scan(&a, &ddl) // log.Println(ddl) if err != nil { panic(err) } return ParserDDL(ddl) } var gmodelVarStr = ` package gmodel import "gorm.io/gorm" // AllModelsGen 所有Model集合,修改单行,只要不改字段名,不会根据新的内容修改,需要修改的话手动删除 type AllModelsGen struct { } func NewAllModels(gdb *gorm.DB) *AllModelsGen { models := &AllModelsGen{ } return models } ` var gmodelVarStrFormat = ` package gmodel import "gorm.io/gorm" // AllModelsGen 所有Model集合,修改单行,只要不改字段名,不会根据新的内容修改,需要修改的话手动删除 type AllModelsGen struct { %s } func NewAllModels(gdb *gorm.DB) *AllModelsGen { models := &AllModelsGen{ %s } return models } ` type TableNameComment struct { Name string GoName string Comment string } type TMCS []TableNameComment func (u TMCS) Len() int { return len(u) } func (u TMCS) Less(i, j int) bool { return u[i].Name < u[j].Name } func (u TMCS) Swap(i, j int) { u[i], u[j] = u[j], u[i] } func GenAllModels(filedir string, tmcs ...TableNameComment) { fileName := filedir + "/var_gen.go" var dupMap map[string]TableNameComment = make(map[string]TableNameComment) for _, tmc := range tmcs { dupMap[tmc.Name] = tmc } if _, err := os.Stat(fileName); err == nil { log.Printf("%s exists!", fileName) data, err := os.ReadFile(fileName) if err != nil { panic(err) } filestr := string(data) filelines := strings.Split(filestr, "\n") re := regexp.MustCompile(`([A-Za-z0-9_]+) [^/]+ // ([^ ]+) (.+)$`) for _, line := range filelines { result := re.FindStringSubmatch(line) if len(result) > 0 { // key := result[0] if len(result) != 4 { log.Println(result) } log.Println(result) tmc := TableNameComment{ Name: result[2], GoName: result[1], Comment: result[3], } if newTmc, ok := dupMap[tmc.Name]; ok { log.Printf("not change: (old)%v -> (new)%v", tmc, newTmc) } dupMap[tmc.Name] = tmc } } tmcs = nil for _, tmc := range dupMap { tmcs = append(tmcs, tmc) } sort.Sort(TMCS(tmcs)) structStr := "" newModelsStr := "" for _, tmc := range tmcs { fsline := fmt.Sprintf("%s *%sModel // %s %s\n", tmc.GoName, tmc.GoName, tmc.Name, tmc.Comment) structStr += fsline nmline := fmt.Sprintf("%s: New%sModel(gdb),\n", tmc.GoName, tmc.GoName) newModelsStr += nmline } content := fmt.Sprintf(gmodelVarStrFormat, structStr, newModelsStr) f, err := os.OpenFile(fileName, os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0644) if err != nil { panic(err) } _, err = f.WriteString(content) if err != nil { panic(err) } } else if os.IsExist(err) { f, err := os.Create(fileName) if err != nil { panic(err) } _, err = f.WriteString(gmodelVarStr) if err != nil { panic(err) } } else { panic(err) } err := exec.Command("gofmt", "-w", fileName).Run() if err != nil { panic(err) } } func main() { var mysqluri string var name string // 需要序列化的单独文件名 var mdir string // 需要修改的序列化路径 model flag.StringVar(&mysqluri, "uri", "fusentest:XErSYmLELKMnf3Dh@tcp(110.41.19.98:3306)/fusentest", "输入需要序列化的ddl文件名, 不需要后缀.ddl") flag.StringVar(&name, "name", "", "输入需要序列化的ddl文件名, 不需要后缀.ddl") flag.StringVar(&mdir, "mdir", "", "输入需要生成model的Go文件所在目录") flag.Parse() if mdir != "" { testGenDir = mdir } db, err := sql.Open("mysql", mysqluri) if err != nil { panic(err) } defer db.Close() var tmcs []TableNameComment if name == "-" { tablenames := GetAllTableNames(mysqluri) for _, testName := range tablenames { cols, tname, tcomment := GetColsFromTable(testName, db) GenFromPath(testGenDir, cols, tname, tcomment) tmcs = append(tmcs, TableNameComment{ Name: tname, GoName: toPascalCase(tname), Comment: tcomment, }) } } else { if name != "" { testName = name } // log.Println(testName) cols, tname, tcomment := GetColsFromTable(testName, db) GenFromPath(testGenDir, cols, tname, tcomment) tmcs = append(tmcs, TableNameComment{ Name: tname, GoName: toPascalCase(tname), Comment: tcomment, }) } GenAllModels(testGenDir, tmcs...) // tablenames := GetAllTableNames(mysqluri) // log.Println(tablenames) // name } func GenFromPath(mdir string, cols []Column, tableName string, tableComment string) { var importstr = "import (\"gorm.io/gorm\"\n" // 匹配到主键定义 fcontent := "package gmodel\n" structstr := "// %s %s\ntype %s struct {%s\n}\n" pTableName := toPascalCase(tableName) fieldstr := "" for _, col := range cols { fieldName := toPascalCase(col.Name) typeName := typeForMysqlToGo[col.GetType()] var defaultString string if col.DefaultValue != nil { switch typeName { case "*int64", "*uint64", "*float64", "*bool": defaultString = "default:" + strings.Trim(*col.DefaultValue, "'") + ";" default: defaultString = "default:" + *col.DefaultValue + ";" } } else { switch typeName { case "*string": defaultString = "default:'';" case "*time.Time": defaultString = "default:'0000-00-00 00:00:00';" case "*[]byte": defaultString = "default:'';" case "*int64", "*uint64": defaultString = "default:0;" case "*float64": defaultString = "default: 0.0;" case "*bool": defaultString = "default:0;" default: fieldName = "// " + fieldName + " " + col.Type } } if typeName == "*time.Time" { importstr += "\"time\"\n" } if col.IndexType == "primary_key" { typeName = typeName[1:] } tagstr := "`gorm:" gormTag := "" if col.IndexType != "" { gormTag += col.IndexType + ";" } gormTag += defaultString if col.AutoIncrement { gormTag += "auto_increment;" } tagstr += fmt.Sprintf("\"%s\"", gormTag) tagstr += fmt.Sprintf(" json:\"%s\"`", col.Name) fieldColStr := fmt.Sprintf("\n%s %s %s// %s", fieldName, typeName, tagstr, col.Comment) fieldstr += fieldColStr } fcontent += importstr + ")\n" fcontent += fmt.Sprintf(structstr, tableName, tableComment, pTableName, fieldstr) modelstr := fmt.Sprintf(`type %sModel struct {db *gorm.DB name string}`, pTableName) fcontent += modelstr fcontent += "\n" newfuncstr := fmt.Sprintf(`func New%sModel(db *gorm.DB) *%sModel {return &%sModel{db:db,name:"%s"}}`, pTableName, pTableName, pTableName, tableName) fcontent += newfuncstr fcontent += "\n" genGoFileName := fmt.Sprintf("%s/%s_gen.go", mdir, tableName) f, err := os.OpenFile(genGoFileName, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) if err != nil { panic(err) } f.WriteString(fcontent) err = f.Close() if err != nil { panic(err) } err = exec.Command("gofmt", "-w", genGoFileName).Run() if err != nil { panic(err) } fcontent = "package gmodel\n// TODO: 使用model的属性做你想做的" genGoLogicFileName := fmt.Sprintf("%s/%s_logic.go", mdir, tableName) // 使用 os.Stat 函数获取文件信息 _, err = os.Stat(genGoLogicFileName) // 判断文件是否存在并输出结果 if os.IsNotExist(err) { f2, err := os.OpenFile(genGoLogicFileName, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) if err != nil { panic(err) } f2.WriteString(fcontent) err = f2.Close() if err != nil { panic(err) } fmt.Println(genGoLogicFileName, "create!") } else { fmt.Println(genGoLogicFileName, "exists") } } type Column struct { Name string Type string DefaultValue *string Length int Decimal int Unsigned bool NotNull bool AutoIncrement bool Comment string IndexType string } func (col *Column) GetType() string { content := col.Type if col.Unsigned { return content + " unsigned" } return content } var typeForMysqlToGo = map[string]string{ // 整数 "int": "*int64", "integer": "*int64", "tinyint": "*int64", "smallint": "*int64", "mediumint": "*int64", "bigint": "*int64", "year": "*int64", "int unsigned": "*int64", "integer unsigned": "*int64", "tinyint unsigned": "*int64", "smallint unsigned": "*int64", "mediumint unsigned": "*int64", "bigint unsigned": "*int64", "bit": "*int64", // 布尔类型 "bool": "*bool", // 字符串 "enum": "*string", "set": "*string", "varchar": "*string", "char": "*string", "tinytext": "*string", "mediumtext": "*string", "text": "*string", "longtext": "*string", // 二进制 "binary": "*[]byte", "varbinary": "*[]byte", "blob": "*[]byte", "tinyblob": "*[]byte", "mediumblob": "*[]byte", "longblob": "*[]byte", // 日期时间 "date": "*time.Time", "datetime": "*time.Time", "timestamp": "*time.Time", "time": "*time.Time", // 浮点数 "float": "*float64", "double": "*float64", "decimal": "*float64", } func ParserDDL(ddl string) (result []Column, tableName, tableComment string) { reTable := regexp.MustCompile(`CREATE TABLE +([^ ]+) +\(`) reTableComment := regexp.MustCompile(`.+COMMENT='(.+)'$`) reField := regexp.MustCompile("`([^`]+)` +([^ \n\\(\\,]+)(?:\\(([^)]+)\\))?( +unsigned| +UNSIGNED)?( +not +null| +NOT +NULL)?( +default +\\'[^\\']*'| +DEFAULT +\\'[^\\']*')?( +auto_increment| +AUTO_INCREMENT)?( comment '[^']*'| COMMENT '[^']*')?(,)?") reIndex := regexp.MustCompile(`(?i)(PRIMARY|UNIQUE)?\s*(INDEX|KEY)\s*(` + "`([^`]*)`" + `)?\s*\(([^)]+)\)`) reValue := regexp.MustCompile(` '(.+)'$`) reDefaultValue := regexp.MustCompile(` ('.+')$`) var fieldmap map[string]string = make(map[string]string) indexMatches := reIndex.FindAllStringSubmatch(ddl, -1) for _, m := range indexMatches { idxAttr := strings.Trim(m[5], "`") PrefixName := strings.ToUpper(m[1]) if PrefixName == "PRIMARY" { fieldmap[idxAttr] = "primary_key" } else if PrefixName == "UNIQUE" { fieldmap[idxAttr] = "unique_key" } else if PrefixName == "" { fieldmap[idxAttr] = "index" } else { log.Fatal(PrefixName) } } tableMatches := reTable.FindStringSubmatch(ddl) tableName = strings.Trim(tableMatches[1], "`") tableCommentMatches := reTableComment.FindStringSubmatch(ddl) if len(tableCommentMatches) > 0 { tableComment = strings.Trim(tableCommentMatches[1], "`") } // log.Println(tableName, tableComment) fieldMatches := reField.FindAllStringSubmatch(ddl, -1) for _, m := range fieldMatches { if m[0] == "" { continue } col := Column{ Name: m[1], Type: strings.ToLower(m[2]), } col.IndexType = fieldmap[col.Name] if m[3] != "" { maylen := strings.Split(m[3], ",") if len(maylen) >= 1 { clen, err := strconv.ParseInt(maylen[0], 10, 64) if err != nil { panic(err) } col.Length = int(clen) } if len(maylen) >= 2 { clen, err := strconv.ParseInt(maylen[1], 10, 64) if err != nil { panic(err) } col.Decimal = int(clen) } } if len(m[4]) > 0 { col.Unsigned = true } if len(m[5]) > 0 { col.NotNull = true } if len(m[6]) > 0 { v := reDefaultValue.FindStringSubmatch(m[6]) if len(v) > 0 { dv := string(v[1]) col.DefaultValue = &dv } } if len(m[7]) > 0 { col.AutoIncrement = true } if len(m[8]) > 0 { v := reValue.FindStringSubmatch(m[8]) if len(v) > 0 { col.Comment = v[1] } } result = append(result, col) // fmt.Println(col) } return result, tableName, tableComment }