Skip to content

Commit

Permalink
Update and enhance DDL for better database handling
Browse files Browse the repository at this point in the history
Extended the database handling capabilities with several key enhancements to the DDL. This includes adding new variables for different data types, modifying default behavior for specific column names, and dynamically adjusting field types. These changes result in better adaptability to handle diverse database schemas and improve the robustness of query generation.
  • Loading branch information
iesreza committed Feb 5, 2024
1 parent bd221c4 commit 89e9262
Showing 1 changed file with 65 additions and 19 deletions.
84 changes: 65 additions & 19 deletions lib/db/schema/ddl/ddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,15 @@ import (
"strings"
)

var Engine = "mysql"

var EngineDataTypes = map[string]map[string]string{
"mariadb": {"json": "longtext"},
}

var InternalFunctions = [][]string{
{"CURRENT_TIMESTAMP", "CURRENT_TIMESTAMP()", "current_timestamp()", "current_timestamp", "NOW()", "now()", "CURRENT_DATE", "CURRENT_DATE()", "current_date", "current_date()"},
{"NULL", "null"},
}

var (
Expand Down Expand Up @@ -47,6 +54,7 @@ type Column struct {
OnUpdate string
Collate string
ForeignKey string
After string
}

type Columns []Column
Expand Down Expand Up @@ -163,12 +171,12 @@ func FromStatement(stmt *gorm.Statement) Table {
}
if column.Name == "created_at" {
column.Nullable = false
column.Default = "CURRENT_TIMESTAMP"
column.Default = "CURRENT_TIMESTAMP()"
}
if column.Name == "updated_at" {
column.Nullable = false
column.Default = "CURRENT_TIMESTAMP"
column.OnUpdate = "CURRENT_TIMESTAMP"
column.Default = "CURRENT_TIMESTAMP()"
column.OnUpdate = "CURRENT_TIMESTAMP()"
}

if column.Unique {
Expand Down Expand Up @@ -254,16 +262,27 @@ func (table Table) GetCreateQuery() []string {

func getFieldQuery(field *Column) string {
var query = quote(field.Name)
query += " " + field.Type
query += " " + fieldType(field.Type)
if field.AutoIncrement {
query += " AUTO_INCREMENT"
}

if field.Default != "" {
var v = field.Default
if field.Default != "" {
if !strings.HasSuffix(field.Default, "()") {
v = strconv.Quote(v)
if (strings.ToLower(field.Type) == "timestamp" || strings.ToLower(field.Type) == "datetime") && !field.Nullable && field.Default == "" {
v = strconv.Quote("0000-00-00 00:00:00")
} else {
var needQuote = true
for _, fns := range InternalFunctions {
if slices.Contains(fns, field.Default) {
needQuote = false
break
}
}
if needQuote {
v = strconv.Quote(v)
}
}
}
query += " DEFAULT " + v
Expand Down Expand Up @@ -318,12 +337,19 @@ func (local Table) GetDiff(remote table.Table) []string {
}

if r := remote.Columns.GetColumn(field.Name); r == nil {
var position = ""
if idx > 0 {
position = " AFTER " + quote(local.Columns[idx-1].Name)
}
queries = append(queries, fmt.Sprintf("-- column %s does not exists", field.Name))
queries = append(queries, "ALTER TABLE "+quote(local.Name)+" ADD "+getFieldQuery(&field)+";")
queries = append(queries, "ALTER TABLE "+quote(local.Name)+" ADD "+getFieldQuery(&field)+position+";")
} else {
var diff = false
if strings.ToLower(field.Type) != strings.ToLower(strings.ToLower(r.ColumnType)) {
queries = append(queries, fmt.Sprintf("-- type does not match. new:%s old:%s", field.Type, strings.ToLower(r.ColumnType)))
if idx > 0 && idx < len(remote.Columns) && remote.Columns[idx].Name != field.Name {
diff = true
}
if fieldType(strings.ToLower(field.Type)) != fieldType(strings.ToLower(r.ColumnType)) {
queries = append(queries, fmt.Sprintf("-- type does not match. new:%s old:%s", fieldType(field.Type), strings.ToLower(r.ColumnType)))
diff = true
}
if len(field.Collate) > 0 && strings.ToLower(field.Collate) != strings.ToLower(r.Collation) {
Expand All @@ -345,6 +371,9 @@ func (local Table) GetDiff(remote table.Table) []string {
skip = true
}
}
if field.Default == "" && getString(r.ColumnDefault) == "0000-00-00 00:00:00" {
skip = true
}

if !skip && !(field.Default == "NULL" && r.ColumnDefault == nil) {
queries = append(queries, fmt.Sprintf("-- default value does not match. new:%s old:%s", field.Default, getString(r.ColumnDefault)))
Expand All @@ -362,10 +391,14 @@ func (local Table) GetDiff(remote table.Table) []string {
needPK = true
}
if diff {
var position = ""
if idx > 0 {
position = " AFTER " + quote(local.Columns[idx-1].Name)
}
if needPK {
afterPK = append(afterPK, "ALTER TABLE "+quote(local.Name)+" MODIFY COLUMN "+getFieldQuery(&field)+";")
afterPK = append(afterPK, "ALTER TABLE "+quote(local.Name)+" MODIFY COLUMN "+getFieldQuery(&field)+position+";")
} else {
queries = append(queries, "ALTER TABLE "+quote(local.Name)+" MODIFY COLUMN "+getFieldQuery(&field)+";")
queries = append(queries, "ALTER TABLE "+quote(local.Name)+" MODIFY COLUMN "+getFieldQuery(&field)+position+";")
}

}
Expand Down Expand Up @@ -454,6 +487,15 @@ func (local Table) GetDiff(remote table.Table) []string {
return queries
}

func fieldType(t string) string {
if v, ok := EngineDataTypes[Engine]; ok {
if v, ok := v[t]; ok {
return v
}
}
return t
}

func getInt(v *int) int {
if v == nil {
return 0
Expand Down Expand Up @@ -494,7 +536,7 @@ func GetCollate(statement *gorm.Statement) string {
return DefaultCollation
}

func (local Table) Constrains(constraints []table.Constraint) []string {
func (local Table) Constrains(constraints []table.Constraint, is table.Tables) []string {
var queries []string
for idx, _ := range local.Columns {
var field = local.Columns[idx]
Expand All @@ -509,17 +551,20 @@ func (local Table) Constrains(constraints []table.Constraint) []string {
referencedCol = tb.PrimaryKey[0]
}
}*/
if len(chunks) == 2 {
if len(chunks) == 1 {
var tb = is.GetTable(chunks[0])
if tb != nil && len(tb.PrimaryKey) > 0 {
referencedTable = tb.Table
referencedCol = tb.PrimaryKey[0].Name
}
} else if len(chunks) == 2 {
referencedTable = chunks[0]
referencedCol = chunks[1]
/* if tb := schema.Find(chunks[0]); tb != nil {
referencedTable = tb.Table
referencedCol = tb.PrimaryKey[0]
}*/
}

if referencedTable != "" && referencedCol != "" {

var name = "fk_" + referencedTable + "_" + field.Name + "_" + referencedTable + "_" + referencedCol
var name = "fk_" + local.Name + "." + field.Name + "_" + referencedTable + "." + referencedCol

var skip = false
for _, constraint := range constraints {
Expand All @@ -531,7 +576,8 @@ func (local Table) Constrains(constraints []table.Constraint) []string {

if !skip {
queries = append(queries, "-- create foreign key")
queries = append(queries, "ALTER TABLE "+quote(local.Name)+" ADD CONSTRAINT "+quote(name)+" FOREIGN KEY ("+quote(field.Name)+") REFERENCES "+quote(referencedTable)+"("+quote(referencedCol)+") ON DELETE SET NULL ON UPDATE RESTRICT")
var onDelete = "RESTRICT"
queries = append(queries, "ALTER TABLE "+quote(local.Name)+" ADD CONSTRAINT "+quote(name)+" FOREIGN KEY ("+quote(field.Name)+") REFERENCES "+quote(referencedTable)+"("+quote(referencedCol)+") ON DELETE "+onDelete+" ON UPDATE RESTRICT")
}
/* ALTER TABLE `rabbits`
ADD CONSTRAINT `fk_rabbits_main_page` FOREIGN KEY IF NOT EXISTS
Expand Down

0 comments on commit 89e9262

Please sign in to comment.