Skip to content

Commit

Permalink
Simplify the dbRow interface
Browse files Browse the repository at this point in the history
This also removes the unused fromRow.

Signed-off-by: mprahl <[email protected]>
(cherry picked from commit 094bd2b)
  • Loading branch information
mprahl authored and magic-mirror-bot[bot] committed Jan 5, 2024
1 parent 4230160 commit 06e3417
Showing 1 changed file with 1 addition and 104 deletions.
105 changes: 1 addition & 104 deletions controllers/complianceeventsapi/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ var (
)

type dbRow interface {
FromRow(*sql.Rows) error
GetOrCreate(ctx context.Context, db *sql.DB) error
InsertQuery() (string, []any)
SelectQuery(returnedColumns ...string) (string, []any)
}
Expand Down Expand Up @@ -108,10 +106,6 @@ func (c Cluster) Validate() error {
return errors.Join(errs...)
}

func (c *Cluster) FromRow(rows *sql.Rows) error {
return fromRow(rows, c)
}

func (c *Cluster) InsertQuery() (string, []any) {
sql := `INSERT INTO clusters (cluster_id, name) VALUES ($1, $2)`
values := []any{c.ClusterID, c.Name}
Expand Down Expand Up @@ -174,10 +168,6 @@ func (e EventDetails) Validate() error {
return errors.Join(errs...)
}

func (e *EventDetails) FromRow(rows *sql.Rows) error {
return fromRow(rows, e)
}

func (e *EventDetails) InsertQuery() (string, []any) {
sql := `INSERT INTO compliance_events` +
`(cluster_id, compliance, message, metadata, parent_policy_id, policy_id, reported_by, timestamp) ` +
Expand Down Expand Up @@ -212,10 +202,6 @@ func (p ParentPolicy) Validate() error {
return errors.Join(errs...)
}

func (p *ParentPolicy) FromRow(rows *sql.Rows) error {
return fromRow(rows, p)
}

func (p *ParentPolicy) InsertQuery() (string, []any) {
sql := `INSERT INTO parent_policies` +
`(categories, controls, name, namespace, standards) ` +
Expand Down Expand Up @@ -297,10 +283,6 @@ func (p *Policy) Validate() error {
return errors.Join(errs...)
}

func (p *Policy) FromRow(rows *sql.Rows) error {
return fromRow(rows, p)
}

func (p *Policy) InsertQuery() (string, []any) {
sql := `INSERT INTO policies` +
`(api_group, kind, name, namespace, severity, spec, spec_hash)` +
Expand Down Expand Up @@ -386,95 +368,11 @@ func (j *JSONMap) Scan(src interface{}) error {
return json.Unmarshal(source, j)
}

type dbField struct {
// value is the underlying value set on the struct field.
value any
// fieldIndex is the index used to access this struct field on the struct instance this field is a part of. This is
// useful to set the value on the struct after a database query.
fieldIndex int
// goType is the Go type of this struct field. This is used to marhsall a SELECT query result to a struct.
goType reflect.Type
}

// getFields parses the input object and returns a map where the keys are database column names and the values are
// dbField instances representing the struct fields on the input object.
func getFields(obj any) map[string]dbField {
fields := map[string]dbField{}

values := reflect.Indirect(reflect.ValueOf(obj))
typesOf := values.Type()

for i := 0; i < values.NumField(); i++ {
structField := typesOf.Field(i)

fieldName := structField.Tag.Get("db")
if fieldName == "" {
// Skip struct fields that are not database columns.
continue
}

reflectField := values.Field(i)

field := dbField{
fieldIndex: i,
goType: structField.Type,
}

if reflectField.IsZero() {
field.value = nil
} else {
field.value = reflect.Indirect(reflectField).Interface()
}

fields[fieldName] = field
}

return fields
}

// fromRow assigns the output from `rows.Scan` to the input `obj`. Note that this does not explicitly support
// nested structs.
func fromRow(rows *sql.Rows, obj any) error {
columns, err := rows.Columns()
if err != nil {
return err
}

objFields := getFields(obj)

values := reflect.Indirect(reflect.ValueOf(obj))

// scans is a slice of all the struct fields that will receive a value from the row
scans := make([]any, 0, len(columns))

for _, column := range columns {
if _, ok := objFields[column]; !ok {
panic("The column is not defined on the struct: " + column)
}

scans = append(scans, reflect.New(objFields[column].goType).Interface())
}

err = rows.Scan(scans...)
if err != nil {
return err
}

for i, column := range columns {
field := values.Field(objFields[column].fieldIndex)
src := reflect.Indirect(reflect.ValueOf(scans[i]))
field.Set(src)
}

return nil
}

// getOrCreate will translate the input object to an INSERT SQL query. When the input object already exists in the
// database, a SELECT query is performed. The primary key is set on the input object when it is inserted or gotten
// from the database. The INSERT first then SELECT approach is a clean way to account for race conditions of multiple
// goroutines creating the same row.
func getOrCreate(ctx context.Context, db *sql.DB, obj dbRow) error {
dbFields := getFields(obj)
insertQuery, insertArgs := obj.InsertQuery()

// On inserts, it returns the primary key value (e.g. id). If it already exists, nothing is returned.
Expand Down Expand Up @@ -507,8 +405,7 @@ func getOrCreate(ctx context.Context, db *sql.DB, obj dbRow) error {

// Set the primary key value on the object
values := reflect.Indirect(reflect.ValueOf(obj))
field := values.Field(dbFields["id"].fieldIndex)
field.Set(reflect.ValueOf(primaryKey))
values.FieldByName("KeyID").Set(reflect.ValueOf(primaryKey))

return nil
}

0 comments on commit 06e3417

Please sign in to comment.