diff --git a/migrator.go b/migrator.go index 09f9d83..460ee27 100644 --- a/migrator.go +++ b/migrator.go @@ -3,6 +3,7 @@ package dameng import ( "database/sql" "fmt" + "strings" "gorm.io/gorm" "gorm.io/gorm/clause" @@ -15,8 +16,38 @@ type Migrator struct { Dialector } +// AutoMigrate 自动迁移模型为表结构 +// +// // 迁移并设置单个表注释 +// db.Set("gorm:table_comments", "用户信息表").AutoMigrate(&User{}) +// +// // 迁移并设置多个表注释 +// db.Set("gorm:table_comments", []string{"用户信息表", "公司信息表"}).AutoMigrate(&User{}, &Company{}) func (m Migrator) AutoMigrate(dst ...interface{}) error { - return m.Migrator.AutoMigrate(dst...) + if err := m.Migrator.AutoMigrate(dst...); err != nil { + return err + } + if tableComments, ok := m.DB.Get("gorm:table_comments"); ok { + var comments []string + switch c := tableComments.(type) { + case string: + comments = []string{c} + case []string: + comments = c + default: + return nil + } + for i := 0; i < len(dst) && i < len(comments); i++ { + value := dst[i] + comment := strings.ReplaceAll(comments[i], "'", "''") + if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { + return m.DB.Exec(fmt.Sprintf("COMMENT ON TABLE ? IS '%s'", comment), m.CurrentTable(stmt)).Error + }); err != nil { + return err + } + } + } + return nil } func (m Migrator) CurrentDatabase() (name string) {