123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296 |
- package main
- import (
- "database/sql"
- "flag"
- "fmt"
- "log"
- "os"
- "strings"
- "github.com/xinliangnote/go-gin-api/cmd/mysqlmd/mysql"
- "gorm.io/gorm"
- )
// tableInfo holds one row read from information_schema.tables: the table's
// name and its comment, which may be NULL. queryTables fills it positionally
// via rows.Scan; nothing in this file reads the db struct tags.
type tableInfo struct {
	Name    string         `db:"table_name"`    // table name
	Comment sql.NullString `db:"table_comment"` // table comment, possibly NULL
}
// tableColumn holds one row read from information_schema.columns, describing
// a single column of a table. queryTableColumn scans into it positionally in
// field order; nothing in this file reads the db struct tags. Fields that
// information_schema may report as NULL use sql.NullString.
type tableColumn struct {
	OrdinalPosition uint16         `db:"ORDINAL_POSITION"` // position of the column within its table
	ColumnName      string         `db:"COLUMN_NAME"`      // column name
	ColumnType      string         `db:"COLUMN_TYPE"`      // full declared type, e.g. as shown in DDL
	DataType        string         `db:"DATA_TYPE"`        // bare type name, fed to textType
	ColumnKey       sql.NullString `db:"COLUMN_KEY"`       // index/key marker, possibly NULL
	IsNullable      string         `db:"IS_NULLABLE"`      // nullability flag as reported by MySQL
	Extra           sql.NullString `db:"EXTRA"`            // extra attributes, possibly NULL
	ColumnComment   sql.NullString `db:"COLUMN_COMMENT"`   // column comment, possibly NULL
	ColumnDefault   sql.NullString `db:"COLUMN_DEFAULT"`   // default value, NULL when none
}
// Connection settings and table filter. init fills these from command-line
// flags; dbName and genTables are lower-cased there.
var (
	dbAddr, dbUser, dbPass string // MySQL address (e.g. 127.0.0.1:3306), user name and password
	dbName                 string // schema whose tables are processed
	genTables              string // "*" for every table, or a comma-separated list of names
)
- func init() {
- addr := flag.String("addr", "", "请输入 db 地址,例如:127.0.0.1:3306\n")
- user := flag.String("user", "", "请输入 db 用户名\n")
- pass := flag.String("pass", "", "请输入 db 密码\n")
- name := flag.String("name", "", "请输入 db 名称\n")
- table := flag.String("tables", "*", "请输入 table 名称,默认为“*”,多个可用“,”分割\n")
- flag.Parse()
- dbAddr = *addr
- dbUser = *user
- dbPass = *pass
- dbName = strings.ToLower(*name)
- genTables = strings.ToLower(*table)
- }
- func main() {
- // 初始化 DB
- db, err := mysql.New(dbAddr, dbUser, dbPass, dbName)
- if err != nil {
- log.Fatal("new db err", err)
- }
- defer func() {
- if err := db.DbClose(); err != nil {
- log.Println("db close err", err)
- }
- }()
- tables, err := queryTables(db.GetDb(), dbName, genTables)
- if err != nil {
- log.Println("query tables of database err", err)
- return
- }
- for _, table := range tables {
- filepath := "./internal/repository/mysql/" + table.Name
- _ = os.Mkdir(filepath, 0766)
- fmt.Println("create dir : ", filepath)
- mdName := fmt.Sprintf("%s/gen_table.md", filepath)
- mdFile, err := os.OpenFile(mdName, os.O_CREATE|os.O_TRUNC|os.O_RDWR, 0766)
- if err != nil {
- fmt.Printf("markdown file error %v\n", err.Error())
- return
- }
- fmt.Println(" └── file : ", table.Name+"/gen_table.md")
- modelName := fmt.Sprintf("%s/gen_model.go", filepath)
- modelFile, err := os.OpenFile(modelName, os.O_CREATE|os.O_TRUNC|os.O_RDWR, 0766)
- if err != nil {
- fmt.Printf("create and open model file error %v\n", err.Error())
- return
- }
- fmt.Println(" └── file : ", table.Name+"/gen_model.go")
- modelContent := fmt.Sprintf("package %s\n", table.Name)
- modelContent += fmt.Sprintf(`import "time"`)
- modelContent += fmt.Sprintf("\n\n// %s %s \n", capitalize(table.Name), table.Comment.String)
- modelContent += fmt.Sprintf("//go:generate gormgen -structs %s -input . \n", capitalize(table.Name))
- modelContent += fmt.Sprintf("type %s struct {\n", capitalize(table.Name))
- tableContent := fmt.Sprintf("#### %s.%s \n", dbName, table.Name)
- if table.Comment.String != "" {
- tableContent += table.Comment.String + "\n"
- }
- tableContent += "\n" +
- "| 序号 | 名称 | 描述 | 类型 | 键 | 为空 | 额外 | 默认值 |\n" +
- "| :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: |\n"
- columnInfo, columnInfoErr := queryTableColumn(db.GetDb(), dbName, table.Name)
- if columnInfoErr != nil {
- continue
- }
- for _, info := range columnInfo {
- tableContent += fmt.Sprintf(
- "| %d | %s | %s | %s | %s | %s | %s | %s |\n",
- info.OrdinalPosition,
- info.ColumnName,
- strings.ReplaceAll(strings.ReplaceAll(info.ColumnComment.String, "|", "\\|"), "\n", ""),
- info.ColumnType,
- info.ColumnKey.String,
- info.IsNullable,
- info.Extra.String,
- info.ColumnDefault.String,
- )
- if textType(info.DataType) == "time.Time" {
- modelContent += fmt.Sprintf("%s %s `%s` // %s\n", capitalize(info.ColumnName), textType(info.DataType), "gorm:\"time\"", info.ColumnComment.String)
- } else {
- modelContent += fmt.Sprintf("%s %s // %s\n", capitalize(info.ColumnName), textType(info.DataType), info.ColumnComment.String)
- }
- }
- mdFile.WriteString(tableContent)
- mdFile.Close()
- modelContent += "}\n"
- modelFile.WriteString(modelContent)
- modelFile.Close()
- }
- }
- func queryTables(db *gorm.DB, dbName string, tableName string) ([]tableInfo, error) {
- var tableCollect []tableInfo
- var tableArray []string
- var commentArray []sql.NullString
- sqlTables := fmt.Sprintf("SELECT `table_name`,`table_comment` FROM `information_schema`.`tables` WHERE `table_schema`= '%s'", dbName)
- rows, err := db.Raw(sqlTables).Rows()
- if err != nil {
- return tableCollect, err
- }
- defer rows.Close()
- for rows.Next() {
- var info tableInfo
- err = rows.Scan(&info.Name, &info.Comment)
- if err != nil {
- fmt.Printf("execute query tables action error,had ignored, detail is [%v]\n", err.Error())
- continue
- }
- tableCollect = append(tableCollect, info)
- tableArray = append(tableArray, info.Name)
- commentArray = append(commentArray, info.Comment)
- }
- // filter tables when specified tables params
- if tableName != "*" {
- tableCollect = nil
- chooseTables := strings.Split(tableName, ",")
- indexMap := make(map[int]int)
- for _, item := range chooseTables {
- subIndexMap := getTargetIndexMap(tableArray, item)
- for k, v := range subIndexMap {
- if _, ok := indexMap[k]; ok {
- continue
- }
- indexMap[k] = v
- }
- }
- if len(indexMap) != 0 {
- for _, v := range indexMap {
- var info tableInfo
- info.Name = tableArray[v]
- info.Comment = commentArray[v]
- tableCollect = append(tableCollect, info)
- }
- }
- }
- return tableCollect, err
- }
- func queryTableColumn(db *gorm.DB, dbName string, tableName string) ([]tableColumn, error) {
- // 定义承载列信息的切片
- var columns []tableColumn
- sqlTableColumn := fmt.Sprintf("SELECT `ORDINAL_POSITION`,`COLUMN_NAME`,`COLUMN_TYPE`,`DATA_TYPE`,`COLUMN_KEY`,`IS_NULLABLE`,`EXTRA`,`COLUMN_COMMENT`,`COLUMN_DEFAULT` FROM `information_schema`.`columns` WHERE `table_schema`= '%s' AND `table_name`= '%s' ORDER BY `ORDINAL_POSITION` ASC",
- dbName, tableName)
- rows, err := db.Raw(sqlTableColumn).Rows()
- if err != nil {
- fmt.Printf("execute query table column action error, detail is [%v]\n", err.Error())
- return columns, err
- }
- defer rows.Close()
- for rows.Next() {
- var column tableColumn
- err = rows.Scan(
- &column.OrdinalPosition,
- &column.ColumnName,
- &column.ColumnType,
- &column.DataType,
- &column.ColumnKey,
- &column.IsNullable,
- &column.Extra,
- &column.ColumnComment,
- &column.ColumnDefault)
- if err != nil {
- fmt.Printf("query table column scan error, detail is [%v]\n", err.Error())
- return columns, err
- }
- columns = append(columns, column)
- }
- return columns, err
- }
// getTargetIndexMap reports every position in tableNameArr whose value equals
// item. The result maps each matching index to itself, so it acts as a set of
// indexes. When nothing matches, it is an empty (non-nil) map.
func getTargetIndexMap(tableNameArr []string, item string) map[int]int {
	matches := make(map[int]int)
	for idx, name := range tableNameArr {
		if name != item {
			continue
		}
		matches[idx] = idx
	}
	return matches
}
// capitalize converts a snake_case name to PascalCase. It works as follows:
//   - s is split on "_";
//   - the first rune of each segment is upper-cased when it is an ASCII letter
//     a-z;
//   - the segments are concatenated with no separator.
//
// Empty segments (from leading, trailing or repeated underscores) contribute
// nothing. Non-ASCII first runes are left as they are.
func capitalize(s string) string {
	var b strings.Builder
	b.Grow(len(s))
	for _, part := range strings.Split(s, "_") {
		for i, r := range part {
			if i == 0 && r >= 'a' && r <= 'z' {
				r -= 'a' - 'A'
			}
			b.WriteRune(r)
		}
	}
	return b.String()
}
// mysqlTypeToGoType maps lower-case MySQL DATA_TYPE names to the Go type used
// for the generated model field. It is built once here rather than on every
// textType call.
//
// NOTE(review): only the bare type name is consulted, so signedness and width
// are ignored (e.g. an unsigned int may not fit in int32). Also, "integer"
// maps to int64 while "int" maps to int32 — confirm this is intended.
var mysqlTypeToGoType = map[string]string{
	"tinyint":    "int32",
	"smallint":   "int32",
	"mediumint":  "int32",
	"int":        "int32",
	"integer":    "int64",
	"bigint":     "int64",
	"float":      "float64",
	"double":     "float64",
	"decimal":    "float64",
	"date":       "string",
	"time":       "string",
	"year":       "string",
	"datetime":   "time.Time",
	"timestamp":  "time.Time",
	"char":       "string",
	"varchar":    "string",
	"tinyblob":   "string",
	"tinytext":   "string",
	"blob":       "string",
	"text":       "string",
	"mediumblob": "string",
	"mediumtext": "string",
	"longblob":   "string",
	"longtext":   "string",
}

// textType returns the Go type name for MySQL data type s. The lookup is
// case-insensitive. Types not in the table (json, enum, set, bit, binary, …)
// fall back to "string", so every generated struct field has a type and the
// model file stays compilable.
func textType(s string) string {
	if goType, ok := mysqlTypeToGoType[strings.ToLower(s)]; ok {
		return goType
	}
	return "string"
}
|