fusenapi/generator/main.go
2023-08-22 13:43:13 +08:00

562 lines
12 KiB
Go

package main
import (
"database/sql"
"flag"
"fmt"
"log"
"os"
"os/exec"
"regexp"
"sort"
"strconv"
"strings"
"golang.org/x/text/cases"
"golang.org/x/text/language"
_ "github.com/go-sql-driver/mysql"
)
// Fallback values used by main: testName is the table generated when -name
// is empty, and testGenDir is the output directory used when -mdir is empty.
var (
	testName   = "fs_auth_item"
	testGenDir = "model/gmodel"
)
// toPascalCase converts a snake_case identifier such as "fs_auth_item" into
// PascalCase ("FsAuthItem"). The input is split on "_", and each part is
// lower-cased and then title-cased. The parts are concatenated without a
// separator.
func toPascalCase(s string) string {
	// Build the title caser once per call instead of once per word.
	// transform.String resets it before every use, so reuse within a call is
	// safe. It is deliberately not a package-level variable: a cases.Caser
	// may be stateful and must not be shared between goroutines.
	caser := cases.Title(language.English)
	words := strings.Split(s, "_")
	for i, word := range words {
		words[i] = caser.String(strings.ToLower(word))
	}
	return strings.Join(words, "")
}
// GetAllTableNames opens the database identified by the MySQL DSN uri and
// returns the table names reported by SHOW TABLES, in the order the server
// returns them. It panics on any open, query, scan or iteration error.
func GetAllTableNames(uri string) []string {
	db, err := sql.Open("mysql", uri)
	if err != nil {
		panic(err)
	}
	defer db.Close()

	rows, err := db.Query("SHOW TABLES")
	if err != nil {
		panic(err)
	}
	// Release the result set and its connection even if Scan panics midway.
	defer rows.Close()

	var tableNames []string
	for rows.Next() {
		var tableName string
		if err := rows.Scan(&tableName); err != nil {
			panic(err)
		}
		tableNames = append(tableNames, tableName)
	}
	// rows.Next returns false both at the end of the result set and on error.
	// Only Err tells the two apart, so check it rather than silently
	// returning a truncated table list.
	if err := rows.Err(); err != nil {
		panic(err)
	}
	return tableNames
}
// GetColsFromTable runs SHOW CREATE TABLE for tname on db and parses the
// returned DDL with ParserDDL. It returns the parsed columns, the table name
// and the table comment, and panics if the query fails.
//
// A bare tname is backtick-quoted, so table names that are MySQL reserved
// words (or contain characters such as '-') work. A name that already
// contains a backtick or a '.' (for example "db.table" or "`db`.`table`") is
// passed through verbatim.
func GetColsFromTable(tname string, db *sql.DB) (result []Column, tableName, tableComment string) {
	ident := tname
	if !strings.ContainsAny(tname, "`.") {
		ident = "`" + tname + "`"
	}
	// SHOW CREATE TABLE yields two columns: the table name and its DDL.
	var shownName, ddl string
	if err := db.QueryRow("SHOW CREATE TABLE "+ident).Scan(&shownName, &ddl); err != nil {
		panic(fmt.Errorf("show create table %s: %w", tname, err))
	}
	return ParserDDL(ddl)
}
var gmodelVarStr = `
package gmodel
import "gorm.io/gorm"
// AllModelsGen 所有Model集合,修改单行,只要不改字段名,不会根据新的内容修改,需要修改的话手动删除
type AllModelsGen struct {
}
func NewAllModels(gdb *gorm.DB) *AllModelsGen {
models := &AllModelsGen{
}
return models
}
`
var gmodelVarStrFormat = `
package gmodel
import "gorm.io/gorm"
// AllModelsGen 所有Model集合,修改单行,只要不改字段名,不会根据新的内容修改,需要修改的话手动删除
type AllModelsGen struct {
%s
}
func NewAllModels(gdb *gorm.DB) *AllModelsGen {
models := &AllModelsGen{
%s
}
return models
}
`
// TableNameComment is one model entry of var_gen.go: the table it was
// generated from, the Go identifier used for its field and model type, and
// the table comment.
type TableNameComment struct {
	Name    string // MySQL table name, e.g. "fs_auth_item"
	GoName  string // Go identifier for the AllModelsGen field / <GoName>Model type
	Comment string // table comment from the DDL
}

// TMCS adapts a []TableNameComment to sort.Interface, ordering entries by
// table Name in ascending byte-wise order.
type TMCS []TableNameComment

// Len reports the number of entries.
func (t TMCS) Len() int { return len(t) }

// Less reports whether entry i's table name sorts before entry j's.
func (t TMCS) Less(i, j int) bool {
	left, right := t[i].Name, t[j].Name
	return left < right
}

// Swap exchanges entries i and j.
func (t TMCS) Swap(i, j int) {
	tmp := t[i]
	t[i] = t[j]
	t[j] = tmp
}
// GenAllModels (re)writes <filedir>/var_gen.go. That file declares the
// AllModelsGen struct, with one *<GoName>Model field per table, and the
// NewAllModels(gdb) constructor that initializes every field.
//
// The file lists the union of the entries in tmcs and the entries already
// present in an existing var_gen.go. When a table appears in both, the
// existing line wins, so hand edits to a line survive regeneration; delete
// the line to have it regenerated. When var_gen.go does not exist yet, it is
// created from tmcs alone. Entries are written sorted by table name, and the
// file is formatted with gofmt. Any failure panics.
func GenAllModels(filedir string, tmcs ...TableNameComment) {
	fileName := filedir + "/var_gen.go"

	merged := make(map[string]TableNameComment, len(tmcs))
	for _, tmc := range tmcs {
		merged[tmc.Name] = tmc
	}

	data, err := os.ReadFile(fileName)
	switch {
	case err == nil:
		log.Printf("%s exists!", fileName)
		for _, old := range parseVarGenEntries(string(data)) {
			if newTmc, ok := merged[old.Name]; ok {
				log.Printf("not change: (old)%v -> (new)%v", old, newTmc)
			}
			merged[old.Name] = old
		}
	case os.IsNotExist(err):
		// First run: there are no existing entries to preserve.
	default:
		panic(err)
	}

	all := make([]TableNameComment, 0, len(merged))
	for _, tmc := range merged {
		all = append(all, tmc)
	}
	// Map iteration order is random; sort for a stable file.
	sort.Sort(TMCS(all))

	if err := os.WriteFile(fileName, []byte(renderVarGen(all)), 0644); err != nil {
		panic(err)
	}
	if out, err := exec.Command("gofmt", "-w", fileName).CombinedOutput(); err != nil {
		panic(fmt.Errorf("gofmt %s: %w\n%s", fileName, err, out))
	}
}

// parseVarGenEntries extracts the model entries from the contents of an
// existing var_gen.go. Each entry comes from a struct field line of the form
//
//	FsAuthItem *FsAuthItemModel // fs_auth_item table comment
//
// The comment part is optional because gofmt strips the trailing space left
// behind for tables without a comment; without that, such entries would be
// dropped on the next run.
func parseVarGenEntries(src string) []TableNameComment {
	re := regexp.MustCompile(`([A-Za-z0-9_]+) [^/]+ // ([^ ]+)(?: (.*))?$`)
	var entries []TableNameComment
	for _, line := range strings.Split(src, "\n") {
		m := re.FindStringSubmatch(strings.TrimRight(line, "\r"))
		if m == nil {
			continue
		}
		entries = append(entries, TableNameComment{
			Name:    m[2],
			GoName:  m[1],
			Comment: m[3],
		})
	}
	return entries
}

// renderVarGen fills gmodelVarStrFormat with a struct field line and a
// constructor initializer for every entry, in the given order. The result
// is not yet gofmt-formatted.
func renderVarGen(tmcs []TableNameComment) string {
	var fields, inits strings.Builder
	for _, tmc := range tmcs {
		fmt.Fprintf(&fields, "%s *%sModel // %s %s\n", tmc.GoName, tmc.GoName, tmc.Name, tmc.Comment)
		fmt.Fprintf(&inits, "%s: New%sModel(gdb),\n", tmc.GoName, tmc.GoName)
	}
	return fmt.Sprintf(gmodelVarStrFormat, fields.String(), inits.String())
}
// main generates gorm model files for one table, or for every table, of a
// MySQL database.
//
// Flags:
//
//	-uri   MySQL DSN to read the schema from.
//	-name  table to generate; "-" generates every table returned by SHOW TABLES,
//	       and "" falls back to testName.
//	-mdir  output directory; "" falls back to testGenDir.
//
// For each table, GenFromPath writes <mdir>/<table>_gen.go (plus a
// <table>_logic.go stub if missing). GenAllModels then refreshes
// <mdir>/var_gen.go.
func main() {
	var (
		mysqluri string
		name     string // table to generate; "-" means every table
		mdir     string // output directory for the generated model files
	)
	// NOTE(review): this default embeds what appear to be real database
	// credentials in source control. They should be rotated and supplied via
	// -uri (or the environment) instead. The default is kept only so existing
	// invocations keep working.
	flag.StringVar(&mysqluri, "uri", "fsreaderwriter:XErSYmLELKMnf3Dh@tcp(fusen.cdmigcvz3rle.us-east-2.rds.amazonaws.com:3306)/fusen", "MySQL连接串(DSN), 格式: user:password@tcp(host:port)/dbname")
	flag.StringVar(&name, "name", "", "需要生成model的表名, 传 - 表示生成库中所有表")
	flag.StringVar(&mdir, "mdir", "", "输入需要生成model的Go文件所在目录")
	flag.Parse()

	if mdir != "" {
		testGenDir = mdir
	}

	db, err := sql.Open("mysql", mysqluri)
	if err != nil {
		panic(err)
	}
	defer db.Close()

	var tables []string
	if name == "-" {
		// GetAllTableNames opens its own connection from the same DSN.
		tables = GetAllTableNames(mysqluri)
	} else {
		if name != "" {
			testName = name
		}
		tables = []string{testName}
	}

	tmcs := make([]TableNameComment, 0, len(tables))
	for _, table := range tables {
		cols, tname, tcomment := GetColsFromTable(table, db)
		GenFromPath(testGenDir, cols, tname, tcomment)
		tmcs = append(tmcs, TableNameComment{
			Name:    tname,
			GoName:  toPascalCase(tname),
			Comment: tcomment,
		})
	}
	GenAllModels(testGenDir, tmcs...)
}
// GenFromPath generates the model source for one table into directory mdir:
//
//   - <mdir>/<tableName>_gen.go is always (re)written. It declares a struct
//     named after the table with one gorm/json-tagged field per column, a
//     <Table>Model wrapper holding a *gorm.DB and the table name, and its
//     New<Table>Model constructor. The file is then formatted with gofmt.
//   - <mdir>/<tableName>_logic.go is created with a stub only if it does not
//     exist yet, so hand-written logic is never overwritten.
//
// Columns whose MySQL type has no entry in typeForMysqlToGo are emitted as
// commented-out fields. Any write or gofmt failure panics.
func GenFromPath(mdir string, cols []Column, tableName string, tableComment string) {
	pTableName := toPascalCase(tableName)

	var fields strings.Builder
	needTime := false
	for _, col := range cols {
		line, usesTime := genModelField(col)
		fields.WriteString(line)
		needTime = needTime || usesTime
	}

	var src strings.Builder
	src.WriteString("package gmodel\n")
	src.WriteString("import (\"gorm.io/gorm\"\n")
	if needTime {
		// Import "time" once, however many time columns the table has.
		src.WriteString("\"time\"\n")
	}
	src.WriteString(")\n")
	fmt.Fprintf(&src, "// %s %s\ntype %s struct {%s\n}\n", tableName, tableComment, pTableName, fields.String())
	fmt.Fprintf(&src, `type %sModel struct {db *gorm.DB
name string}`, pTableName)
	src.WriteString("\n")
	fmt.Fprintf(&src, `func New%sModel(db *gorm.DB) *%sModel {return &%sModel{db:db,name:"%s"}}`, pTableName, pTableName, pTableName, tableName)
	src.WriteString("\n")

	genGoFileName := fmt.Sprintf("%s/%s_gen.go", mdir, tableName)
	if err := os.WriteFile(genGoFileName, []byte(src.String()), 0644); err != nil {
		panic(err)
	}
	// Include gofmt's output: a bare "exit status 2" would hide which
	// generated line failed to parse.
	if out, err := exec.Command("gofmt", "-w", genGoFileName).CombinedOutput(); err != nil {
		panic(fmt.Errorf("gofmt %s: %w\n%s", genGoFileName, err, out))
	}

	genLogicFileIfMissing(fmt.Sprintf("%s/%s_logic.go", mdir, tableName))
}

// genModelField renders the struct field for one column, including its
// leading newline:
//
//	\n<Field> <Type> `gorm:"<tags>" json:"<column>"`// <comment>
//
// It also reports whether the field's type is time.Time, so the caller can
// import "time".
func genModelField(col Column) (line string, usesTime bool) {
	fieldName := toPascalCase(col.Name)
	typeName := typeForMysqlToGo[col.GetType()]

	var defaultString string
	if col.DefaultValue != nil {
		switch typeName {
		case "*int64", "*uint64", "*float64", "*bool":
			// DefaultValue keeps the DDL's single quotes (e.g. '0'); numeric and
			// bool defaults are written without them.
			defaultString = "default:" + strings.Trim(*col.DefaultValue, "'") + ";"
		default:
			defaultString = "default:" + *col.DefaultValue + ";"
		}
	} else {
		switch typeName {
		case "*string", "*[]byte":
			defaultString = "default:'';"
		case "*time.Time":
			defaultString = "default:'0000-00-00 00:00:00';"
		case "*int64", "*uint64", "*bool":
			defaultString = "default:0;"
		case "*float64":
			defaultString = "default: 0.0;"
		}
	}

	if typeName == "" {
		// No Go type is known for this MySQL type. Comment the field out
		// whether or not it has a DEFAULT; an uncommented field with an empty
		// type would not compile.
		fieldName = "// " + fieldName + " " + col.Type
	}
	usesTime = typeName == "*time.Time"

	// Primary-key fields use the non-pointer type, presumably because a
	// primary key column is never NULL. The prefix check avoids slicing an
	// empty (unmapped) type name.
	if col.IndexType == "primary_key" && strings.HasPrefix(typeName, "*") {
		typeName = typeName[1:]
	}

	gormTag := ""
	if col.IndexType != "" {
		gormTag += col.IndexType + ";"
	}
	gormTag += defaultString
	if col.AutoIncrement {
		gormTag += "auto_increment;"
	}
	tag := fmt.Sprintf("`gorm:\"%s\" json:\"%s\"`", gormTag, col.Name)
	return fmt.Sprintf("\n%s %s %s// %s", fieldName, typeName, tag, col.Comment), usesTime
}

// genLogicFileIfMissing creates the hand-maintained companion file at path
// with a stub body. It leaves an existing file untouched and reports which
// case happened on stdout.
func genLogicFileIfMissing(path string) {
	if _, err := os.Stat(path); !os.IsNotExist(err) {
		fmt.Println(path, "exists")
		return
	}
	content := "package gmodel\n// TODO: 使用model的属性做你想做的"
	if err := os.WriteFile(path, []byte(content), 0644); err != nil {
		panic(err)
	}
	fmt.Println(path, "create!")
}
// Column is one column definition parsed from a CREATE TABLE statement by
// ParserDDL.
type Column struct {
	Name          string  // column name without backticks
	Type          string  // lower-cased base type, e.g. "varchar", "bigint"
	DefaultValue  *string // quoted DEFAULT literal such as "'0'", or nil if none was captured
	Length        int     // first numeric type argument, e.g. 255 in varchar(255)
	Decimal       int     // second type argument, e.g. 2 in decimal(10,2)
	Unsigned      bool    // column is declared UNSIGNED
	NotNull       bool    // column is declared NOT NULL
	AutoIncrement bool    // column is AUTO_INCREMENT
	Comment       string  // COMMENT text without the quotes
	IndexType     string  // "primary_key", "unique_key", "index", or ""
}

// GetType returns the key used to look the column up in typeForMysqlToGo:
// Type on its own, or Type followed by " unsigned" for UNSIGNED columns.
func (col *Column) GetType() string {
	if !col.Unsigned {
		return col.Type
	}
	return col.Type + " unsigned"
}
// typeForMysqlToGo maps a column type key, as produced by Column.GetType
// (lower-cased base type plus " unsigned" for UNSIGNED columns), to the Go
// type used for the generated model field. Types missing from this map get
// no Go type in the generated model.
var typeForMysqlToGo = map[string]string{
	// Integers. NOTE(review): unsigned integers also map to *int64, so BIGINT
	// UNSIGNED values above math.MaxInt64 cannot be represented.
	"int":                "*int64",
	"integer":            "*int64",
	"tinyint":            "*int64",
	"smallint":           "*int64",
	"mediumint":          "*int64",
	"bigint":             "*int64",
	"year":               "*int64",
	"int unsigned":       "*int64",
	"integer unsigned":   "*int64",
	"tinyint unsigned":   "*int64",
	"smallint unsigned":  "*int64",
	"mediumint unsigned": "*int64",
	"bigint unsigned":    "*int64",
	"bit":                "*int64",
	// Boolean.
	"bool": "*bool",
	// Strings.
	"enum":       "*string",
	"set":        "*string",
	"varchar":    "*string",
	"char":       "*string",
	"tinytext":   "*string",
	"mediumtext": "*string",
	"text":       "*string",
	"longtext":   "*string",
	// Binary.
	"binary":     "*[]byte",
	"varbinary":  "*[]byte",
	"blob":       "*[]byte",
	"tinyblob":   "*[]byte",
	"mediumblob": "*[]byte",
	"longblob":   "*[]byte",
	// Date and time.
	"date":      "*time.Time",
	"datetime":  "*time.Time",
	"timestamp": "*time.Time",
	"time":      "*time.Time",
	// Floating point and fixed point. MySQL still accepts UNSIGNED on these
	// types (deprecated since 8.0.17), and GetType then appends " unsigned",
	// so those keys are listed too.
	"float":            "*float64",
	"double":           "*float64",
	"decimal":          "*float64",
	"float unsigned":   "*float64",
	"double unsigned":  "*float64",
	"decimal unsigned": "*float64",
}
// ParserDDL parses the CREATE TABLE statement returned by MySQL's SHOW CREATE
// TABLE.
//
// It returns one Column per column definition, in DDL order. It also returns
// the table name from the CREATE TABLE clause and the table comment from the
// trailing COMMENT='...' option, or "" when there is none.
//
// Plain single-column index definitions set the column's IndexType to
// "primary_key", "unique_key" or "index". Columns that appear only in
// composite indexes get no IndexType.
//
// Only single-quoted DEFAULT literals are recorded, with their quotes (for
// example "'0'"). DEFAULT NULL, CURRENT_TIMESTAMP and other unquoted defaults
// leave DefaultValue nil. The function panics if the DDL has no CREATE TABLE
// clause.
func ParserDDL(ddl string) (result []Column, tableName, tableComment string) {
	reTable := regexp.MustCompile(`CREATE TABLE +([^ ]+) +\(`)
	reTableComment := regexp.MustCompile(`.+COMMENT='(.+)'$`)
	// One column definition. Capture groups:
	//   1 name, 2 type, 3 type arguments ("255", "10,2", "'a','b'"),
	//   4 UNSIGNED, 5 NOT NULL, 6 DEFAULT clause, 7 AUTO_INCREMENT,
	//   8 COMMENT clause, 9 trailing comma.
	// The non-capturing groups skip attributes that MySQL may print between
	// them (ZEROFILL, CHARACTER SET / COLLATE, NULL, unquoted defaults such as
	// NULL or CURRENT_TIMESTAMP, ON UPDATE ...). Without them, the match would
	// stop early and the DEFAULT and COMMENT after those attributes would be
	// lost.
	reField := regexp.MustCompile("`([^`]+)` +([^ \n(,]+)" +
		`(?:\(([^)]+)\))?` +
		`( +(?i:unsigned))?` +
		`(?: +(?i:zerofill))?` +
		`(?: +(?i:character set|collate) +[^ \n,]+)*` +
		`( +(?i:not) +(?i:null))?` +
		`(?: +(?i:null))?` +
		`( +(?i:default) +(?:'[^']*'|[^ \n,']+))?` +
		`(?: +(?i:on update) +[^ \n,]+)?` +
		`( +(?i:auto_increment))?` +
		`( +(?i:comment) '[^']*')?` +
		`(,)?`)
	reIndex := regexp.MustCompile(`(?i)(PRIMARY|UNIQUE)?\s*(INDEX|KEY)\s*(` + "`([^`]*)`" + `)?\s*\(([^)]+)\)`)
	reValue := regexp.MustCompile(` '(.+)'$`)
	reDefaultValue := regexp.MustCompile(` ('.+')$`)

	// Map column name -> index kind, from the index definitions.
	fieldmap := make(map[string]string)
	for _, m := range reIndex.FindAllStringSubmatch(ddl, -1) {
		idxAttr := strings.Trim(m[5], "`")
		switch prefix := strings.ToUpper(m[1]); prefix {
		case "PRIMARY":
			fieldmap[idxAttr] = "primary_key"
		case "UNIQUE":
			fieldmap[idxAttr] = "unique_key"
		case "":
			fieldmap[idxAttr] = "index"
		default:
			// Unreachable: group 1 only matches PRIMARY or UNIQUE.
			log.Fatal(prefix)
		}
	}

	tableMatches := reTable.FindStringSubmatch(ddl)
	if tableMatches == nil {
		panic(fmt.Sprintf("ParserDDL: no CREATE TABLE clause found in DDL:\n%s", ddl))
	}
	tableName = strings.Trim(tableMatches[1], "`")
	if m := reTableComment.FindStringSubmatch(ddl); len(m) > 0 {
		tableComment = strings.Trim(m[1], "`")
	}

	for _, m := range reField.FindAllStringSubmatch(ddl, -1) {
		col := Column{
			Name:          m[1],
			Type:          strings.ToLower(m[2]),
			IndexType:     fieldmap[m[1]],
			Unsigned:      m[4] != "",
			NotNull:       m[5] != "",
			AutoIncrement: m[7] != "",
		}
		if m[3] != "" {
			// Numeric arguments are length and decimals ("255", "10,2").
			// Non-numeric ones, such as enum/set value lists, are ignored
			// instead of aborting the whole run.
			args := strings.Split(m[3], ",")
			if n, err := strconv.Atoi(strings.TrimSpace(args[0])); err == nil {
				col.Length = n
				if len(args) >= 2 {
					if d, err := strconv.Atoi(strings.TrimSpace(args[1])); err == nil {
						col.Decimal = d
					}
				}
			}
		}
		if v := reDefaultValue.FindStringSubmatch(m[6]); len(v) > 0 {
			dv := v[1]
			col.DefaultValue = &dv
		}
		if v := reValue.FindStringSubmatch(m[8]); len(v) > 0 {
			col.Comment = v[1]
		}
		result = append(result, col)
	}
	return result, tableName, tableComment
}