-
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmigrator.go
76 lines (70 loc) · 1.8 KB
/
migrator.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
package rawsql
import (
	"database/sql"
	"sort"
	"strings"

	"gorm.io/gorm"
	"gorm.io/gorm/migrator"
)
// Migrator is this dialect's GORM migrator. It embeds GORM's generic
// migrator.Migrator for default behavior and this package's Dialector. The
// schema-inspection methods in this file (TableType, ColumnTypes, GetIndexes,
// GetTables) answer from the Dialector's m.tables map rather than issuing
// queries.
type Migrator struct {
// Generic GORM migrator; supplies RunWithValue, DB and the default methods.
migrator.Migrator
// Package dialector; assumed to carry the `tables` map keyed by table name
// (declared elsewhere in the package) — NOTE(review): confirm.
Dialector
}
// TableType reports the schema, name and comment of the table that value maps
// to, using the table metadata held by the dialector. The comment is always
// reported as valid, even when it is empty. If the table is not known, both
// the returned TableType and the error are nil.
func (m Migrator) TableType(value interface{}) (tableType gorm.TableType, err error) {
	err = m.RunWithValue(value, func(stmt *gorm.Statement) error {
		schemaName, name := m.CurrentSchema(stmt, stmt.Table)

		tbl, found := m.tables[name]
		if !found || tbl == nil {
			// Unknown table: leave tableType unset rather than failing.
			return nil
		}

		tableType = &migrator.TableType{
			SchemaValue:  schemaName,
			NameValue:    name,
			CommentValue: sql.NullString{String: tbl.Comment, Valid: true},
		}
		return nil
	})
	return tableType, err
}
// ColumnTypes returns the column metadata recorded by the dialector for the
// table that value maps to. The result is always a non-nil slice: it is empty
// when the table is unknown and otherwise a fresh copy of the stored column
// list. Callers may therefore modify it without touching the dialector's
// table metadata. The error comes from resolving value into a statement.
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
	columnTypes := make([]gorm.ColumnType, 0)
	err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
		_, tableName := m.CurrentSchema(stmt, stmt.Table)
		if table, ok := m.tables[tableName]; ok && table != nil {
			// Copy instead of aliasing the stored slice so callers cannot
			// mutate shared table metadata through the returned value.
			columnTypes = append(make([]gorm.ColumnType, 0, len(table.ColumnTypes)), table.ColumnTypes...)
		}
		return nil
	})
	return columnTypes, err
}
// GetIndexes returns the index metadata recorded by the dialector for the
// table that value maps to. The result is always a non-nil slice: it is empty
// when the table is unknown and otherwise a fresh copy of the stored index
// list. Callers may therefore modify it without touching the dialector's
// table metadata. The error comes from resolving value into a statement.
func (m Migrator) GetIndexes(value interface{}) ([]gorm.Index, error) {
	indexes := make([]gorm.Index, 0)
	err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
		_, tableName := m.CurrentSchema(stmt, stmt.Table)
		if table, ok := m.tables[tableName]; ok && table != nil {
			// Copy instead of aliasing the stored slice so callers cannot
			// mutate shared table metadata through the returned value.
			indexes = append(make([]gorm.Index, 0, len(table.Indexes)), table.Indexes...)
		}
		return nil
	})
	return indexes, err
}
// GetTables returns the names of all tables known to the dialector, sorted
// lexically. Sorting keeps the result deterministic, since Go map iteration
// order is randomized. The returned error is always nil.
func (m Migrator) GetTables() (tableList []string, err error) {
	tableList = make([]string, 0, len(m.tables))
	for name := range m.tables {
		tableList = append(tableList, name)
	}
	sort.Strings(tableList)
	return tableList, nil
}
// CurrentSchema splits a "schema.table" reference into its schema and table
// parts.
//
// Any name that does not contain exactly one dot (no dot, or several) is
// returned unchanged as the table, with an empty schema. stmt is unused but
// kept so the signature matches its callers.
//
// No DB scoping happens here. m is a value receiver, so assigning to m.DB
// would only update a local copy and be discarded on return.
func (m Migrator) CurrentSchema(stmt *gorm.Statement, table string) (string, string) {
	if parts := strings.Split(table, "."); len(parts) == 2 {
		return parts[0], parts[1]
	}
	return "", table
}