main.go 8.1 KB


  1. package main
  2. import (
  3. "database/sql"
  4. "flag"
  5. "fmt"
  6. "log"
  7. "os"
  8. "strings"
  9. "github.com/xinliangnote/go-gin-api/cmd/mysqlmd/mysql"
  10. "gorm.io/gorm"
  11. )
  // tableInfo holds one row read from information_schema.tables.
// Comment is a sql.NullString so that a NULL table_comment scans cleanly.
type tableInfo struct {
	// Name is the table name (column table_name).
	Name string `db:"table_name"`
	// Comment is the table comment (column table_comment).
	Comment sql.NullString `db:"table_comment"`
}
  // tableColumn holds one row read from information_schema.columns.
// Fields typed sql.NullString tolerate NULL values in the query result.
type tableColumn struct {
	// OrdinalPosition is the column's position within its table.
	OrdinalPosition uint16 `db:"ORDINAL_POSITION"`
	// ColumnName is the column's name.
	ColumnName string `db:"COLUMN_NAME"`
	// ColumnType is the full column type as reported by COLUMN_TYPE.
	ColumnType string `db:"COLUMN_TYPE"`
	// DataType is the bare type name as reported by DATA_TYPE.
	DataType string `db:"DATA_TYPE"`
	// ColumnKey is the key indicator (COLUMN_KEY).
	ColumnKey sql.NullString `db:"COLUMN_KEY"`
	// IsNullable is the IS_NULLABLE value.
	IsNullable string `db:"IS_NULLABLE"`
	// Extra is the EXTRA attribute text.
	Extra sql.NullString `db:"EXTRA"`
	// ColumnComment is the column comment.
	ColumnComment sql.NullString `db:"COLUMN_COMMENT"`
	// ColumnDefault is the column's default value.
	ColumnDefault sql.NullString `db:"COLUMN_DEFAULT"`
}
  // Settings filled in from the command-line flags.
var (
	// dbAddr is the database address, e.g. 127.0.0.1:3306 (-addr).
	dbAddr string
	// dbUser is the database user name (-user).
	dbUser string
	// dbPass is the database password (-pass).
	dbPass string
	// dbName is the database (schema) name (-name).
	dbName string
	// genTables is "*" or a comma-separated list of table names (-tables).
	genTables string
)
  34. func init() {
  35. addr := flag.String("addr", "", "请输入 db 地址,例如:127.0.0.1:3306\n")
  36. user := flag.String("user", "", "请输入 db 用户名\n")
  37. pass := flag.String("pass", "", "请输入 db 密码\n")
  38. name := flag.String("name", "", "请输入 db 名称\n")
  39. table := flag.String("tables", "*", "请输入 table 名称,默认为“*”,多个可用“,”分割\n")
  40. flag.Parse()
  41. dbAddr = *addr
  42. dbUser = *user
  43. dbPass = *pass
  44. dbName = strings.ToLower(*name)
  45. genTables = strings.ToLower(*table)
  46. }
  47. func main() {
  48. // 初始化 DB
  49. db, err := mysql.New(dbAddr, dbUser, dbPass, dbName)
  50. if err != nil {
  51. log.Fatal("new db err", err)
  52. }
  53. defer func() {
  54. if err := db.DbClose(); err != nil {
  55. log.Println("db close err", err)
  56. }
  57. }()
  58. tables, err := queryTables(db.GetDb(), dbName, genTables)
  59. if err != nil {
  60. log.Println("query tables of database err", err)
  61. return
  62. }
  63. for _, table := range tables {
  64. filepath := "./internal/repository/mysql/" + table.Name
  65. _ = os.Mkdir(filepath, 0766)
  66. fmt.Println("create dir : ", filepath)
  67. mdName := fmt.Sprintf("%s/gen_table.md", filepath)
  68. mdFile, err := os.OpenFile(mdName, os.O_CREATE|os.O_TRUNC|os.O_RDWR, 0766)
  69. if err != nil {
  70. fmt.Printf("markdown file error %v\n", err.Error())
  71. return
  72. }
  73. fmt.Println(" └── file : ", table.Name+"/gen_table.md")
  74. modelName := fmt.Sprintf("%s/gen_model.go", filepath)
  75. modelFile, err := os.OpenFile(modelName, os.O_CREATE|os.O_TRUNC|os.O_RDWR, 0766)
  76. if err != nil {
  77. fmt.Printf("create and open model file error %v\n", err.Error())
  78. return
  79. }
  80. fmt.Println(" └── file : ", table.Name+"/gen_model.go")
  81. modelContent := fmt.Sprintf("package %s\n", table.Name)
  82. modelContent += fmt.Sprintf(`import "time"`)
  83. modelContent += fmt.Sprintf("\n\n// %s %s \n", capitalize(table.Name), table.Comment.String)
  84. modelContent += fmt.Sprintf("//go:generate gormgen -structs %s -input . \n", capitalize(table.Name))
  85. modelContent += fmt.Sprintf("type %s struct {\n", capitalize(table.Name))
  86. tableContent := fmt.Sprintf("#### %s.%s \n", dbName, table.Name)
  87. if table.Comment.String != "" {
  88. tableContent += table.Comment.String + "\n"
  89. }
  90. tableContent += "\n" +
  91. "| 序号 | 名称 | 描述 | 类型 | 键 | 为空 | 额外 | 默认值 |\n" +
  92. "| :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: |\n"
  93. columnInfo, columnInfoErr := queryTableColumn(db.GetDb(), dbName, table.Name)
  94. if columnInfoErr != nil {
  95. continue
  96. }
  97. for _, info := range columnInfo {
  98. tableContent += fmt.Sprintf(
  99. "| %d | %s | %s | %s | %s | %s | %s | %s |\n",
  100. info.OrdinalPosition,
  101. info.ColumnName,
  102. strings.ReplaceAll(strings.ReplaceAll(info.ColumnComment.String, "|", "\\|"), "\n", ""),
  103. info.ColumnType,
  104. info.ColumnKey.String,
  105. info.IsNullable,
  106. info.Extra.String,
  107. info.ColumnDefault.String,
  108. )
  109. if textType(info.DataType) == "time.Time" {
  110. modelContent += fmt.Sprintf("%s %s `%s` // %s\n", capitalize(info.ColumnName), textType(info.DataType), "gorm:\"time\"", info.ColumnComment.String)
  111. } else {
  112. modelContent += fmt.Sprintf("%s %s // %s\n", capitalize(info.ColumnName), textType(info.DataType), info.ColumnComment.String)
  113. }
  114. }
  115. mdFile.WriteString(tableContent)
  116. mdFile.Close()
  117. modelContent += "}\n"
  118. modelFile.WriteString(modelContent)
  119. modelFile.Close()
  120. }
  121. }
  122. func queryTables(db *gorm.DB, dbName string, tableName string) ([]tableInfo, error) {
  123. var tableCollect []tableInfo
  124. var tableArray []string
  125. var commentArray []sql.NullString
  126. sqlTables := fmt.Sprintf("SELECT `table_name`,`table_comment` FROM `information_schema`.`tables` WHERE `table_schema`= '%s'", dbName)
  127. rows, err := db.Raw(sqlTables).Rows()
  128. if err != nil {
  129. return tableCollect, err
  130. }
  131. defer rows.Close()
  132. for rows.Next() {
  133. var info tableInfo
  134. err = rows.Scan(&info.Name, &info.Comment)
  135. if err != nil {
  136. fmt.Printf("execute query tables action error,had ignored, detail is [%v]\n", err.Error())
  137. continue
  138. }
  139. tableCollect = append(tableCollect, info)
  140. tableArray = append(tableArray, info.Name)
  141. commentArray = append(commentArray, info.Comment)
  142. }
  143. // filter tables when specified tables params
  144. if tableName != "*" {
  145. tableCollect = nil
  146. chooseTables := strings.Split(tableName, ",")
  147. indexMap := make(map[int]int)
  148. for _, item := range chooseTables {
  149. subIndexMap := getTargetIndexMap(tableArray, item)
  150. for k, v := range subIndexMap {
  151. if _, ok := indexMap[k]; ok {
  152. continue
  153. }
  154. indexMap[k] = v
  155. }
  156. }
  157. if len(indexMap) != 0 {
  158. for _, v := range indexMap {
  159. var info tableInfo
  160. info.Name = tableArray[v]
  161. info.Comment = commentArray[v]
  162. tableCollect = append(tableCollect, info)
  163. }
  164. }
  165. }
  166. return tableCollect, err
  167. }
  168. func queryTableColumn(db *gorm.DB, dbName string, tableName string) ([]tableColumn, error) {
  169. // 定义承载列信息的切片
  170. var columns []tableColumn
  171. sqlTableColumn := fmt.Sprintf("SELECT `ORDINAL_POSITION`,`COLUMN_NAME`,`COLUMN_TYPE`,`DATA_TYPE`,`COLUMN_KEY`,`IS_NULLABLE`,`EXTRA`,`COLUMN_COMMENT`,`COLUMN_DEFAULT` FROM `information_schema`.`columns` WHERE `table_schema`= '%s' AND `table_name`= '%s' ORDER BY `ORDINAL_POSITION` ASC",
  172. dbName, tableName)
  173. rows, err := db.Raw(sqlTableColumn).Rows()
  174. if err != nil {
  175. fmt.Printf("execute query table column action error, detail is [%v]\n", err.Error())
  176. return columns, err
  177. }
  178. defer rows.Close()
  179. for rows.Next() {
  180. var column tableColumn
  181. err = rows.Scan(
  182. &column.OrdinalPosition,
  183. &column.ColumnName,
  184. &column.ColumnType,
  185. &column.DataType,
  186. &column.ColumnKey,
  187. &column.IsNullable,
  188. &column.Extra,
  189. &column.ColumnComment,
  190. &column.ColumnDefault)
  191. if err != nil {
  192. fmt.Printf("query table column scan error, detail is [%v]\n", err.Error())
  193. return columns, err
  194. }
  195. columns = append(columns, column)
  196. }
  197. return columns, err
  198. }
  // getTargetIndexMap finds every position in tableNameArr whose value equals
// item. It returns a map from each matching index to itself; the map is
// empty (but non-nil) when nothing matches.
func getTargetIndexMap(tableNameArr []string, item string) map[int]int {
	matches := map[int]int{}
	for idx, name := range tableNameArr {
		if name != item {
			continue
		}
		matches[idx] = idx
	}
	return matches
}
  // capitalize turns a snake_case name into CamelCase. The input is split on
// "_" and empty pieces are dropped. When the first rune of a piece is an
// ASCII letter 'a'..'z' it is upper-cased; every other rune is kept as is.
func capitalize(s string) string {
	var out strings.Builder
	for _, word := range strings.Split(s, "_") {
		letters := []rune(word)
		if len(letters) == 0 {
			continue
		}
		if first := letters[0]; first >= 'a' && first <= 'z' {
			letters[0] = first - ('a' - 'A')
		}
		out.WriteString(string(letters))
	}
	return out.String()
}
  // mysqlTypeToGoType maps a MySQL data type name (the DATA_TYPE value from
// information_schema.columns, e.g. "varchar") to the Go type used for the
// generated model field.
//
// NOTE(review): DATA_TYPE presumably carries no UNSIGNED information, so
// unsigned columns map to the signed Go types below — confirm that large
// unsigned values fit.
var mysqlTypeToGoType = map[string]string{
	"tinyint":    "int32",
	"smallint":   "int32",
	"mediumint":  "int32",
	"int":        "int32",
	"integer":    "int64",
	"bigint":     "int64",
	"float":      "float64",
	"double":     "float64",
	"decimal":    "float64",
	"date":       "string",
	"time":       "string",
	"year":       "string",
	"datetime":   "time.Time",
	"timestamp":  "time.Time",
	"char":       "string",
	"varchar":    "string",
	"tinyblob":   "string",
	"tinytext":   "string",
	"blob":       "string",
	"text":       "string",
	"mediumblob": "string",
	"mediumtext": "string",
	"longblob":   "string",
	"longtext":   "string",
	"json":       "string",
	"enum":       "string",
	"set":        "string",
	"binary":     "string",
	"varbinary":  "string",
}

// textType returns the Go type for the MySQL data type s. The lookup is
// case-insensitive. Types missing from mysqlTypeToGoType fall back to
// "string", so a generated struct field always has a type.
func textType(s string) string {
	if goType, ok := mysqlTypeToGoType[strings.ToLower(s)]; ok {
		return goType
	}
	return "string"
}