diff --git a/expressions/options/options.go b/expressions/options/options.go index 17cc340b1..e82fb769e 100644 --- a/expressions/options/options.go +++ b/expressions/options/options.go @@ -51,16 +51,16 @@ var typeCompatibilityMapping = map[string][][]*types.Type{ {typing.Date, typing.Timestamp, types.TimestampType}, }, operators.Add: { - {types.NewListType(types.IntType), types.NewListType(types.DoubleType), typing.Number, typing.Decimal}, + {types.IntType, types.DoubleType, typing.Number, typing.Decimal}, }, operators.Subtract: { - {types.NewListType(types.IntType), types.NewListType(types.DoubleType), typing.Number, typing.Decimal}, + {types.IntType, types.DoubleType, typing.Number, typing.Decimal}, }, operators.Multiply: { - {types.NewListType(types.IntType), types.NewListType(types.DoubleType), typing.Number, typing.Decimal}, + {types.IntType, types.DoubleType, typing.Number, typing.Decimal}, }, operators.Divide: { - {types.NewListType(types.IntType), types.NewListType(types.DoubleType), typing.Number, typing.Decimal}, + {types.IntType, types.DoubleType, typing.Number, typing.Decimal}, }, } diff --git a/expressions/parser.go b/expressions/parser.go index 8a005c6a3..bee6ad771 100644 --- a/expressions/parser.go +++ b/expressions/parser.go @@ -163,6 +163,8 @@ func typesAssignable(expected *types.Type, actual *types.Type) bool { typing.Markdown.String(): {mapType(typing.Text.String()), mapType(typing.Markdown.String())}, typing.ID.String(): {mapType(typing.Text.String()), mapType(typing.ID.String())}, typing.Text.String(): {mapType(typing.Text.String()), mapType(typing.Markdown.String()), mapType(typing.ID.String())}, + typing.Number.String(): {mapType(typing.Number.String()), mapType(typing.Decimal.String())}, + typing.Decimal.String(): {mapType(typing.Number.String()), mapType(typing.Decimal.String())}, } // Check if there are specific compatibility rules for the expected type diff --git a/expressions/resolve/visitor.go b/expressions/resolve/visitor.go index 80bfe9668..00a1b69a8 100644 --- a/expressions/resolve/visitor.go +++ b/expressions/resolve/visitor.go @@ -74,19 +74,19 @@ func (w *CelVisitor[T]) run(expression *parser.Expression) (T, error) { w.ast = ast - if err := w.eval(checkedExpr.Expr, false); err != nil { + if err := w.eval(checkedExpr.Expr, isComplexOperatorWithRespectTo(operators.LogicalAnd, checkedExpr.Expr), false); err != nil { return zero, err } return w.visitor.Result() } -func (w *CelVisitor[T]) eval(expr *exprpb.Expr, inBinaryCondition bool) error { +func (w *CelVisitor[T]) eval(expr *exprpb.Expr, nested bool, inBinary bool) error { var err error switch expr.ExprKind.(type) { case *exprpb.Expr_ConstExpr, *exprpb.Expr_ListExpr, *exprpb.Expr_SelectExpr, *exprpb.Expr_IdentExpr: - if !inBinaryCondition { + if !inBinary { err := w.visitor.StartCondition(false) if err != nil { return err @@ -96,10 +96,20 @@ func (w *CelVisitor[T]) eval(expr *exprpb.Expr, inBinaryCondition bool) error { switch expr.ExprKind.(type) { case *exprpb.Expr_CallExpr: + err = w.visitor.StartCondition(nested) + if err != nil { + return err + } + err := w.callExpr(expr) if err != nil { return err } + + err = w.visitor.EndCondition(nested) + if err != nil { + return err + } case *exprpb.Expr_ConstExpr: err := w.constExpr(expr) if err != nil { @@ -126,7 +136,7 @@ func (w *CelVisitor[T]) eval(expr *exprpb.Expr, inBinaryCondition bool) error { switch expr.ExprKind.(type) { case *exprpb.Expr_ConstExpr, *exprpb.Expr_ListExpr, *exprpb.Expr_SelectExpr, *exprpb.Expr_IdentExpr: - if !inBinaryCondition { + if !inBinary { err := w.visitor.EndCondition(false) if err != nil { return err @@ -173,18 +183,12 @@ func (w *CelVisitor[T]) binaryCall(expr *exprpb.Expr) error { op := c.GetFunction() args := c.GetArgs() lhs := args[0] - - isComplex := isComplexOperatorWithRespectTo(operators.LogicalAnd, expr) - - err := w.visitor.StartCondition(isComplex) - if err != nil { - return err - } + lhsParen := isComplexOperatorWithRespectTo(op, lhs) + var err error inBinary := !(op == operators.LogicalAnd || op == operators.LogicalOr) - rhs := args[1] - if err := w.eval(lhs, inBinary); err != nil { + if err := w.eval(lhs, lhsParen, inBinary); err != nil { return err } @@ -200,11 +204,17 @@ func (w *CelVisitor[T]) binaryCall(expr *exprpb.Expr) error { return err } - if err := w.eval(rhs, inBinary); err != nil { + rhs := args[1] + rhsParen := isComplexOperatorWithRespectTo(op, rhs) + if !rhsParen && isLeftRecursive(op) { + rhsParen = isSamePrecedence(op, rhs) + } + + if err := w.eval(rhs, rhsParen, inBinary); err != nil { return err } - return w.visitor.EndCondition(isComplex) + return nil } func (w *CelVisitor[T]) unaryCall(expr *exprpb.Expr) error { @@ -224,16 +234,11 @@ func (w *CelVisitor[T]) unaryCall(expr *exprpb.Expr) error { return fmt.Errorf("not implemented: %s", fun) } - err := w.visitor.StartCondition(isComplex) - if err != nil { - return err - } - - if err := w.eval(args[0], false); err != nil { + if err := w.eval(args[0], isComplex, false); err != nil { return err } - return w.visitor.EndCondition(isComplex) + return nil } func (w *CelVisitor[T]) constExpr(expr *exprpb.Expr) error { @@ -362,7 +367,7 @@ func (w *CelVisitor[T]) SelectExpr(expr *exprpb.Expr) error { switch expr.ExprKind.(type) { case *exprpb.Expr_CallExpr: - err := w.eval(sel.GetOperand(), true) + err := w.eval(sel.GetOperand(), true, true) if err != nil { return err } @@ -449,6 +454,25 @@ func isComplexOperatorWithRespectTo(op string, expr *exprpb.Expr) bool { return isLowerPrecedence(op, expr) } +// isLeftRecursive indicates whether the parser resolves the call in a left-recursive manner as +// this can have an effect of how parentheses affect the order of operations in the AST. +func isLeftRecursive(op string) bool { + return op != operators.LogicalAnd && op != operators.LogicalOr +} + +// isSamePrecedence indicates whether the precedence of the input operator is the same as the +// precedence of the (possible) operation represented in the input Expr. +// +// If the expr is not a Call, the result is false. +func isSamePrecedence(op string, expr *exprpb.Expr) bool { + if expr.GetCallExpr() == nil { + return false + } + c := expr.GetCallExpr() + other := c.GetFunction() + return operators.Precedence(op) == operators.Precedence(other) +} + func toNative(c *exprpb.Constant) (any, error) { switch c.ConstantKind.(type) { case *exprpb.Constant_BoolValue: diff --git a/go.mod b/go.mod index aa8dff63b..842ff5bc5 100644 --- a/go.mod +++ b/go.mod @@ -44,6 +44,7 @@ require ( github.com/spf13/viper v1.15.0 github.com/stretchr/testify v1.8.4 github.com/teamkeel/graphql v0.8.2-0.20230531102419-995b8ab035b6 + github.com/test-go/testify v1.1.4 github.com/twitchtv/twirp v8.1.3+incompatible github.com/vincent-petithory/dataurl v1.0.0 github.com/xeipuuv/gojsonschema v1.2.0 diff --git a/go.sum b/go.sum index 9fbeda544..822b8c086 100644 --- a/go.sum +++ b/go.sum @@ -403,6 +403,8 @@ github.com/subosito/gotenv v1.4.2 h1:X1TuBLAMDFbaTAChgCBLu3DU3UPyELpnF2jjJ2cz/S8 github.com/subosito/gotenv v1.4.2/go.mod h1:ayKnFf/c6rvx/2iiLrJUk1e6plDbT3edrFNGqEflhK0= github.com/teamkeel/graphql v0.8.2-0.20230531102419-995b8ab035b6 h1:q8ZbAgqr7jJlZNJ4WAI+QMuZrcCBDOw9k7orYuy+Vqs= github.com/teamkeel/graphql v0.8.2-0.20230531102419-995b8ab035b6/go.mod h1:5td34OA5ZUdckc2w3GgE7QQoaG8MK6hIVR3dFI+qaK4= +github.com/test-go/testify v1.1.4 h1:Tf9lntrKUMHiXQ07qBScBTSA0dhYQlu83hswqelv1iE= +github.com/test-go/testify v1.1.4/go.mod h1:rH7cfJo/47vWGdi4GPj16x3/t1xGOj2YxzmNQzk2ghU= github.com/thoas/go-funk v0.9.1 h1:O549iLZqPpTUQ10ykd26sZhzD+rmR5pWhuElrhbC20M= github.com/thoas/go-funk v0.9.1/go.mod h1:+IWnUfUmFO1+WVYQWQtIJHeRRdaIyyYglZN7xzUPe4Q= github.com/tkuchiki/go-timezone v0.2.0 h1:yyZVHtQRVZ+wvlte5HXvSpBkR0dPYnPEIgq9qqAqltk= diff --git a/integration/testdata/computed_fields/schema.keel b/integration/testdata/computed_fields/schema.keel new file mode 100644 index 000000000..c1f949113 --- /dev/null +++ b/integration/testdata/computed_fields/schema.keel @@ -0,0 +1,46 @@ +model ComputedDecimal { + fields { + price Decimal + quantity Number + total Decimal @computed(computedDecimal.quantity * computedDecimal.price) + totalWithShipping Decimal @computed(5 + computedDecimal.quantity * computedDecimal.price) + totalWithDiscount Decimal @computed(computedDecimal.quantity * (computedDecimal.price - (computedDecimal.price / 100 * 10))) + } +} + +model ComputedNumber { + fields { + price Decimal + quantity Number + total Number @computed(computedNumber.quantity * computedNumber.price) + totalWithShipping Number @computed(5 + computedNumber.quantity * computedNumber.price) + totalWithDiscount Number @computed(computedNumber.quantity * (computedNumber.price - (computedNumber.price / 100 * 10))) + } +} + +model ComputedBool { + fields { + price Decimal? + isActive Boolean + isExpensive Boolean @computed(computedBool.price > 100 && computedBool.isActive) + isCheap Boolean @computed(!computedBool.isExpensive) + } +} + +model ComputedNulls { + fields { + price Decimal? + quantity Number? + total Decimal? @computed(computedNulls.quantity * computedNulls.price) + } +} + +model ComputedDepends { + fields { + price Decimal + quantity Number + totalWithDiscount Decimal? @computed(computedDepends.totalWithShipping - (computedDepends.totalWithShipping / 100 * 10)) + totalWithShipping Decimal? @computed(computedDepends.total + 5) + total Decimal? @computed(computedDepends.quantity * computedDepends.price) + } +} diff --git a/integration/testdata/computed_fields/tests.test.ts b/integration/testdata/computed_fields/tests.test.ts new file mode 100644 index 000000000..0c75a00d2 --- /dev/null +++ b/integration/testdata/computed_fields/tests.test.ts @@ -0,0 +1,139 @@ +import { test, expect, beforeEach } from "vitest"; +import { models, resetDatabase } from "@teamkeel/testing"; + +beforeEach(resetDatabase); + +test("computed fields - decimal", async () => { + const item = await models.computedDecimal.create({ price: 5, quantity: 2 }); + expect(item.total).toEqual(10); + expect(item.totalWithShipping).toEqual(15); + expect(item.totalWithDiscount).toEqual(9); + + const get = await models.computedDecimal.findOne({ id: item.id }); + expect(get!.total).toEqual(10); + expect(get!.totalWithShipping).toEqual(15); + expect(get!.totalWithDiscount).toEqual(9); + + const updatePrice = await models.computedDecimal.update( + { id: item.id }, + { price: 10 } + ); + expect(updatePrice.total).toEqual(20); + expect(updatePrice.totalWithShipping).toEqual(25); + expect(updatePrice.totalWithDiscount).toEqual(18); + + const updateQuantity = await models.computedDecimal.update( + { id: item.id }, + { quantity: 3 } + ); + expect(updateQuantity.total).toEqual(30); + expect(updateQuantity.totalWithShipping).toEqual(35); + expect(updateQuantity.totalWithDiscount).toEqual(27); + + const updateBoth = await models.computedDecimal.update( + { id: item.id }, + { price: 12, quantity: 4 } + ); + expect(updateBoth.total).toEqual(48); + expect(updateBoth.totalWithShipping).toEqual(53); + expect(updateBoth.totalWithDiscount).toEqual(43.2); +}); + +test("computed fields - number", async () => { + const item = await models.computedNumber.create({ price: 5, quantity: 2 }); + expect(item.total).toEqual(10); + expect(item.totalWithShipping).toEqual(15); + expect(item.totalWithDiscount).toEqual(9); + + const get = await models.computedNumber.findOne({ id: item.id }); + expect(get!.total).toEqual(10); + expect(get!.totalWithShipping).toEqual(15); + expect(get!.totalWithDiscount).toEqual(9); + + const updatePrice = await models.computedNumber.update( + { id: item.id }, + { price: 10 } + ); + expect(updatePrice.total).toEqual(20); + expect(updatePrice.totalWithShipping).toEqual(25); + expect(updatePrice.totalWithDiscount).toEqual(18); + + const updateQuantity = await models.computedNumber.update( + { id: item.id }, + { quantity: 3 } + ); + expect(updateQuantity.total).toEqual(30); + expect(updateQuantity.totalWithShipping).toEqual(35); + expect(updateQuantity.totalWithDiscount).toEqual(27); + + const updateBoth = await models.computedNumber.update( + { id: item.id }, + { price: 12, quantity: 4 } + ); + expect(updateBoth.total).toEqual(48); + expect(updateBoth.totalWithShipping).toEqual(53); + expect(updateBoth.totalWithDiscount).toEqual(43); +}); + +test("computed fields - boolean", async () => { + const expensive = await models.computedBool.create({ + price: 200, + isActive: true, + }); + expect(expensive.isExpensive).toBeTruthy(); + expect(expensive.isCheap).toBeFalsy(); + + const notExpensive = await models.computedBool.create({ + price: 90, + isActive: true, + }); + expect(notExpensive.isExpensive).toBeFalsy(); + expect(notExpensive.isCheap).toBeTruthy(); + + const notActive = await models.computedBool.create({ + price: 200, + isActive: false, + }); + expect(notActive.isExpensive).toBeFalsy(); + expect(notActive.isCheap).toBeTruthy(); +}); + +test("computed fields - with nulls", async () => { + const item = await models.computedNulls.create({ price: 5 }); + expect(item.total).toBeNull(); + + const updateQty = await models.computedNulls.update( + { id: item.id }, + { quantity: 10 } + ); + expect(updateQty!.total).toEqual(50); + + const updatePrice2 = await models.computedNulls.update( + { id: item.id }, + { price: null } + ); + expect(updatePrice2!.total).toBeNull(); +}); + +test("computed fields - with dependencies", async () => { + const item = await models.computedDepends.create({ price: 5, quantity: 2 }); + expect(item.total).toEqual(10); + expect(item.totalWithShipping).toEqual(15); + expect(item.totalWithDiscount).toEqual(13.5); + + const updatedQty = await models.computedDepends.update( + { id: item.id }, + { quantity: 10 } + ); + expect(updatedQty.total).toEqual(50); + expect(updatedQty.totalWithShipping).toEqual(55); + expect(updatedQty.totalWithDiscount).toEqual(49.5); + + const updatePrice = await models.computedDepends.update( + { id: item.id }, + { price: 8 } + ); + expect(updatePrice.total).toEqual(80); + expect(updatePrice.totalWithShipping).toEqual(85); + expect(updatePrice.totalWithDiscount).toEqual(76.5); +}); diff --git a/migrations/computed_functions.sql b/migrations/computed_functions.sql new file mode 100644 index 000000000..aaef2371b --- /dev/null +++ b/migrations/computed_functions.sql @@ -0,0 +1,8 @@ +SELECT + routine_name +FROM + information_schema.routines +WHERE + routine_type = 'FUNCTION' +AND + routine_schema = 'public' AND routine_name LIKE '%__computed'; \ No newline at end of file diff --git a/migrations/introspection.go b/migrations/introspection.go index 8e006d6f1..885faa8f4 100644 --- a/migrations/introspection.go +++ b/migrations/introspection.go @@ -22,6 +22,11 @@ func getColumns(database db.Database) ([]*ColumnRow, error) { return rows, database.GetDB().Raw(columnsQuery).Scan(&rows).Error } +func getComputedFunctions(database db.Database) ([]*FunctionRow, error) { + rows := []*FunctionRow{} + return rows, database.GetDB().Raw(computedFunctionsQuery).Scan(&rows).Error +} + var ( //go:embed columns.sql columnsQuery string @@ -31,6 +36,9 @@ var ( //go:embed triggers.sql triggersQuery string + + //go:embed computed_functions.sql + computedFunctionsQuery string ) type ColumnRow struct { @@ -80,3 +88,7 @@ type TriggerRow struct { // e.g. AFTER ActionTiming string `json:"action_timing"` } + +type FunctionRow struct { + RoutineName string `json:"routine_name"` +} diff --git a/migrations/migrations.go b/migrations/migrations.go index ce17547e1..ae6f902c4 100644 --- a/migrations/migrations.go +++ b/migrations/migrations.go @@ -5,6 +5,7 @@ import ( _ "embed" "errors" "fmt" + "slices" "strings" "github.com/iancoleman/strcase" @@ -12,7 +13,9 @@ import ( "github.com/teamkeel/keel/auditing" "github.com/teamkeel/keel/casing" "github.com/teamkeel/keel/db" + "github.com/teamkeel/keel/expressions/resolve" "github.com/teamkeel/keel/proto" + "github.com/teamkeel/keel/schema/parser" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "google.golang.org/protobuf/encoding/protojson" @@ -174,7 +177,7 @@ func New(ctx context.Context, schema *proto.Schema, database db.Database) (*Migr return nil, err } - triggers, err := getTriggers(database) + existingTriggers, err := getTriggers(database) if err != nil { return nil, err } @@ -251,10 +254,10 @@ func New(ctx context.Context, schema *proto.Schema, database db.Database) (*Migr // Add audit log triggers all model tables excluding the audit table itself. for _, model := range schema.Models { if model.Name != strcase.ToCamel(auditing.TableName) { - stmt := createAuditTriggerStmts(triggers, model) + stmt := createAuditTriggerStmts(existingTriggers, model) statements = append(statements, stmt) - stmt = createUpdatedAtTriggerStmts(triggers, model) + stmt = createUpdatedAtTriggerStmts(existingTriggers, model) statements = append(statements, stmt) } } @@ -361,6 +364,38 @@ func New(ctx context.Context, schema *proto.Schema, database db.Database) (*Migr } } + // Fetch all computed functions in the database + existingComputedFns, err := getComputedFunctions(database) + if err != nil { + return nil, err + } + + // Computed fields functions and triggers + computedChanges, stmts, err := computedFieldsStmts(schema, existingComputedFns) + if err != nil { + return nil, err + } + + for _, change := range computedChanges { + // Dont add the db change if the field was already modified elsewhere + if lo.ContainsBy(changes, func(c *DatabaseChange) bool { + return c.Model == change.Model && c.Field == change.Field + }) { + continue + } + + // Dont add the db change if the model is new + if lo.ContainsBy(changes, func(c *DatabaseChange) bool { + return c.Model == change.Model && c.Field == "" && c.Type == ChangeTypeAdded + }) { + continue + } + + changes = append(changes, computedChanges...) + } + + statements = append(statements, stmts...) + stringChanges := lo.Map(changes, func(c *DatabaseChange, _ int) string { return c.String() }) span.SetAttributes(attribute.StringSlice("migration", stringChanges)) @@ -417,6 +452,187 @@ func compositeUniqueConstraints(schema *proto.Schema, model *proto.Model, constr return statements, nil } +// computedFieldDependencies returns a map of computed fields and every field it depends on +func computedFieldDependencies(schema *proto.Schema) (map[*proto.Field][]*proto.Field, error) { + dependencies := map[*proto.Field][]*proto.Field{} + + for _, model := range schema.Models { + for _, field := range model.Fields { + if field.ComputedExpression == nil { + continue + } + + expr, err := parser.ParseExpression(field.ComputedExpression.Source) + if err != nil { + return nil, err + } + + idents, err := resolve.IdentOperands(expr) + if err != nil { + return nil, err + } + + for _, ident := range idents { + for _, f := range schema.FindModel(strcase.ToCamel(ident.Fragments[0])).Fields { + if f.Name == ident.Fragments[1] { + dependencies[field] = append(dependencies[field], f) + break + } + } + } + } + } + + return dependencies, nil +} + +// computedFieldsStmts generates SQL statements for dropping or creating functions and triggers for computed fields +func computedFieldsStmts(schema *proto.Schema, existingComputedFns []*FunctionRow) (changes []*DatabaseChange, statements []string, err error) { + existingComputedFnNames := lo.Map(existingComputedFns, func(f *FunctionRow, _ int) string { + return f.RoutineName + }) + + fns := map[string]string{} + fieldsFns := map[*proto.Field]string{} + changedFields := map[*proto.Field]bool{} + + // Adding computed field triggers and functions + for _, model := range schema.Models { + modelFns := map[string]string{} + + for _, field := range model.GetComputedFields() { + changedFields[field] = false + fnName, computedFuncStmt, err := addComputedFieldFuncStmt(schema, model, field) + if err != nil { + return nil, nil, err + } + + fieldsFns[field] = fnName + modelFns[fnName] = computedFuncStmt + } + + // Get all the preexisting computed functions for computed fields on this model + existingComputedFnNamesForModel := lo.Filter(existingComputedFnNames, func(f string, _ int) bool { + return strings.HasPrefix(f, fmt.Sprintf("%s__", strcase.ToSnake(model.Name))) && + strings.HasSuffix(f, "_computed") + }) + + newFns, retiredFns := lo.Difference(lo.Keys(modelFns), existingComputedFnNamesForModel) + slices.Sort(newFns) + slices.Sort(retiredFns) + + // Functions to be created + for _, fn := range newFns { + statements = append(statements, modelFns[fn]) + + f := fieldFromComputedFnName(schema, fn) + changes = append(changes, &DatabaseChange{ + Model: f.ModelName, + Field: f.Name, + Type: ChangeTypeModified, + }) + changedFields[f] = true + } + + // Functions to be dropped + for _, fn := range retiredFns { + statements = append(statements, fmt.Sprintf("DROP FUNCTION %s;", fn)) + + f := fieldFromComputedFnName(schema, fn) + if f != nil { + change := &DatabaseChange{ + Model: f.ModelName, + Field: f.Name, + Type: ChangeTypeModified, + } + if !lo.ContainsBy(changes, func(c *DatabaseChange) bool { + return c.Model == change.Model && c.Field == change.Field + }) { + changes = append(changes, change) + } + changedFields[f] = true + } + } + + // When there all computed fields have been removed + if len(modelFns) == 0 && len(retiredFns) > 0 { + dropExecFn := dropComputedExecFunctionStmt(model) + dropTrigger := dropComputedTriggerStmt(model) + statements = append(statements, dropTrigger, dropExecFn) + } + + for k, v := range modelFns { + fns[k] = v + } + } + + dependencies, err := computedFieldDependencies(schema) + if err != nil { + return nil, nil, err + } + + for _, model := range schema.Models { + modelhasChanged := false + for k, v := range changedFields { + if k.ModelName == model.Name && v { + modelhasChanged = true + } + } + if !modelhasChanged { + continue + } + + computedFields := model.GetComputedFields() + if len(computedFields) == 0 { + continue + } + + // Sort fields based on dependencies + sorted := []*proto.Field{} + visited := map[*proto.Field]bool{} + var visit func(*proto.Field) + visit = func(field *proto.Field) { + if visited[field] || field.ComputedExpression == nil { + return + } + visited[field] = true + + // Process dependencies first + for _, dep := range dependencies[field] { + if dep.ModelName == field.ModelName { + visit(dep) + } + } + sorted = append(sorted, field) + } + + // Visit all fields to build sorted order + for _, field := range computedFields { + visit(field) + } + + // Generate SQL statements in dependency order + stmts := []string{} + for _, field := range sorted { + s := fmt.Sprintf("\tNEW.%s := %s(%s);\n", strcase.ToSnake(field.Name), fieldsFns[field], "NEW") + stmts = append(stmts, s) + } + + execFnName := computedExecFuncName(model) + triggerName := computedTriggerName(model) + + // Generate the trigger function which executes all the computed field functions for the model. + sql := fmt.Sprintf("CREATE OR REPLACE FUNCTION %s() RETURNS TRIGGER AS $$ BEGIN\n%s\tRETURN NEW;\nEND; $$ LANGUAGE plpgsql;", execFnName, strings.Join(stmts, "")) + + // Genrate the table trigger which executed the trigger function. + trigger := fmt.Sprintf("CREATE OR REPLACE TRIGGER %s BEFORE INSERT OR UPDATE ON \"%s\" FOR EACH ROW EXECUTE PROCEDURE %s();", triggerName, strcase.ToSnake(model.Name), execFnName) + + statements = append(statements, sql, trigger) + } + + return +} + func keelSchemaTableExists(ctx context.Context, database db.Database) (bool, error) { // to_regclass docs - https://www.postgresql.org/docs/current/functions-info.html#FUNCTIONS-INFO-CATALOG-TABLE // translates a textual relation name to its OID ... this function will diff --git a/migrations/migrations_test.go b/migrations/migrations_test.go index eb8eabbc5..dd19b9e11 100644 --- a/migrations/migrations_test.go +++ b/migrations/migrations_test.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "encoding/json" + "fmt" "os" "path/filepath" "regexp" @@ -108,7 +109,11 @@ func TestMigrations(t *testing.T) { require.NoError(t, err) // Assert correct SQL generated - assert.Equal(t, expectedSQL, m.SQL) + equal := assert.Equal(t, expectedSQL, m.SQL) + + if !equal { + fmt.Println(m.SQL) + } actualChanges, err := json.Marshal(m.Changes) require.NoError(t, err) diff --git a/migrations/sql.go b/migrations/sql.go index a059480b6..335aacde0 100644 --- a/migrations/sql.go +++ b/migrations/sql.go @@ -1,6 +1,7 @@ package migrations import ( + "crypto/sha256" "fmt" "regexp" "strings" @@ -12,6 +13,7 @@ import ( "github.com/teamkeel/keel/db" "github.com/teamkeel/keel/expressions/resolve" "github.com/teamkeel/keel/proto" + "github.com/teamkeel/keel/runtime/actions" "github.com/teamkeel/keel/schema/parser" "golang.org/x/exp/slices" ) @@ -90,6 +92,7 @@ func createTableStmt(schema *proto.Schema, model *proto.Model) (string, error) { PrimaryKeyConstraintName(model.Name, field.Name), Identifier(field.Name))) } + if field.Unique && !field.PrimaryKey { uniqueStmt, err := addUniqueConstraintStmt(schema, model.Name, []string{field.Name}) if err != nil { @@ -228,6 +231,74 @@ func alterColumnStmt(modelName string, field *proto.Field, column *ColumnRow) (s return strings.Join(stmts, "\n"), nil } +// computedFieldFuncName generates the name of the a computed field's function +func computedFieldFuncName(model *proto.Model, field *proto.Field) string { + // shortened alphanumeric hash from an expression + hash := fmt.Sprintf("%x", sha256.Sum256([]byte(field.ComputedExpression.Source)))[:8] + return fmt.Sprintf("%s__%s__%s__computed", strcase.ToSnake(model.Name), strcase.ToSnake(field.Name), hash) +} + +// computedExecFuncName generates the name for the table function which executed all computed functions +func computedExecFuncName(model *proto.Model) string { + return fmt.Sprintf("%s__exec_computed_fns", strcase.ToSnake(model.Name)) +} + +// computedTriggerName generates the name for the trigger which runs the function which executes computed functions +func computedTriggerName(model *proto.Model) string { + return fmt.Sprintf("%s__computed_trigger", strcase.ToSnake(model.Name)) +} + +// fieldFromComputedFnName determines the field from computed function name +func fieldFromComputedFnName(schema *proto.Schema, fn string) *proto.Field { + parts := strings.Split(fn, "__") + model := schema.FindModel(strcase.ToCamel(parts[0])) + for _, f := range model.Fields { + if f.Name == strcase.ToLowerCamel(parts[1]) { + return f + } + } + return nil +} + +// addComputedFieldFuncStmt generates the function for a computed field +func addComputedFieldFuncStmt(schema *proto.Schema, model *proto.Model, field *proto.Field) (string, string, error) { + var sqlType string + switch field.Type.Type { + case proto.Type_TYPE_DECIMAL, proto.Type_TYPE_INT, proto.Type_TYPE_BOOL: + sqlType = PostgresFieldTypes[field.Type.Type] + default: + return "", "", fmt.Errorf("type not supported for computed fields: %s", field.Type.Type) + } + + expression, err := parser.ParseExpression(field.ComputedExpression.Source) + if err != nil { + return "", "", err + } + + // Generate SQL from the computed attribute expression to set this field + stmt, err := resolve.RunCelVisitor(expression, actions.GenerateComputedFunction(schema, model, field)) + if err != nil { + return "", "", err + } + + fn := computedFieldFuncName(model, field) + sql := fmt.Sprintf("CREATE FUNCTION %s(r %s) RETURNS %s AS $$ BEGIN\n\tRETURN %s;\nEND; $$ LANGUAGE plpgsql;", + fn, + strcase.ToSnake(model.Name), + sqlType, + stmt) + + return fn, sql, nil +} + +func dropComputedExecFunctionStmt(model *proto.Model) string { + return fmt.Sprintf("DROP FUNCTION %s__exec_computed_fns;", strcase.ToSnake(model.Name)) +} + +func dropComputedTriggerStmt(model *proto.Model) string { + return fmt.Sprintf("DROP TRIGGER %s__computed_trigger ON %s;", strcase.ToSnake(model.Name), strcase.ToSnake(model.Name)) +} + func fieldDefinition(field *proto.Field) (string, error) { columnName := Identifier(field.Name) diff --git a/migrations/testdata/computed_field_changed_expression.txt b/migrations/testdata/computed_field_changed_expression.txt new file mode 100644 index 000000000..7f60b7065 --- /dev/null +++ b/migrations/testdata/computed_field_changed_expression.txt @@ -0,0 +1,35 @@ +model Item { + fields { + price Decimal + quantity Number + total Decimal @computed(item.quantity * item.price) + } +} + +=== + +model Item { + fields { + price Decimal + quantity Number + total Decimal @computed(item.price + 5) + } +} + +=== + +CREATE FUNCTION item__total__863346d0__computed(r item) RETURNS NUMERIC AS $$ BEGIN + RETURN r."price" + 5; +END; $$ LANGUAGE plpgsql; +DROP FUNCTION item__total__0614a79a__computed; +CREATE OR REPLACE FUNCTION item__exec_computed_fns() RETURNS TRIGGER AS $$ BEGIN + NEW.total := item__total__863346d0__computed(NEW); + RETURN NEW; +END; $$ LANGUAGE plpgsql; +CREATE OR REPLACE TRIGGER item__computed_trigger BEFORE INSERT OR UPDATE ON "item" FOR EACH ROW EXECUTE PROCEDURE item__exec_computed_fns(); + +=== + +[ + {"Model":"Item","Field":"total","Type":"MODIFIED"} +] diff --git a/migrations/testdata/computed_field_initial.txt b/migrations/testdata/computed_field_initial.txt new file mode 100644 index 000000000..5900302ee --- /dev/null +++ b/migrations/testdata/computed_field_initial.txt @@ -0,0 +1,78 @@ +=== + +model Item { + fields { + price Decimal + quantity Number + total Decimal @computed(item.quantity * item.price) + } +} + +=== + +CREATE TABLE "identity" ( +"email" TEXT, +"email_verified" BOOL NOT NULL DEFAULT false, +"password" TEXT, +"external_id" TEXT, +"issuer" TEXT, +"name" TEXT, +"given_name" TEXT, +"family_name" TEXT, +"middle_name" TEXT, +"nick_name" TEXT, +"profile" TEXT, +"picture" TEXT, +"website" TEXT, +"gender" TEXT, +"zone_info" TEXT, +"locale" TEXT, +"id" TEXT NOT NULL DEFAULT ksuid(), +"created_at" TIMESTAMPTZ NOT NULL DEFAULT now(), +"updated_at" TIMESTAMPTZ NOT NULL DEFAULT now() +); +ALTER TABLE "identity" ADD CONSTRAINT identity_id_pkey PRIMARY KEY ("id"); +ALTER TABLE "identity" ADD CONSTRAINT identity_email_issuer_udx UNIQUE ("email", "issuer"); +CREATE TABLE "item" ( +"price" NUMERIC NOT NULL, +"quantity" INTEGER NOT NULL, +"total" NUMERIC NOT NULL, +"id" TEXT NOT NULL DEFAULT ksuid(), +"created_at" TIMESTAMPTZ NOT NULL DEFAULT now(), +"updated_at" TIMESTAMPTZ NOT NULL DEFAULT now() +); +ALTER TABLE "item" ADD CONSTRAINT item_id_pkey PRIMARY KEY ("id"); +CREATE TABLE "keel_audit" ( +"id" TEXT NOT NULL DEFAULT ksuid(), +"table_name" TEXT NOT NULL, +"op" TEXT NOT NULL, +"data" jsonb NOT NULL, +"created_at" TIMESTAMPTZ NOT NULL DEFAULT now(), +"identity_id" TEXT, +"trace_id" TEXT, +"event_processed_at" TIMESTAMPTZ +); +ALTER TABLE "keel_audit" ADD CONSTRAINT keel_audit_id_pkey PRIMARY KEY ("id"); +CREATE TRIGGER item_create AFTER INSERT ON "item" REFERENCING NEW TABLE AS new_table FOR EACH STATEMENT EXECUTE PROCEDURE process_audit(); +CREATE TRIGGER item_update AFTER UPDATE ON "item" REFERENCING NEW TABLE AS new_table OLD TABLE AS old_table FOR EACH STATEMENT EXECUTE PROCEDURE process_audit(); +CREATE TRIGGER item_delete AFTER DELETE ON "item" REFERENCING OLD TABLE AS old_table FOR EACH STATEMENT EXECUTE PROCEDURE process_audit(); +CREATE TRIGGER item_updated_at BEFORE UPDATE ON "item" FOR EACH ROW EXECUTE PROCEDURE set_updated_at(); +CREATE TRIGGER identity_create AFTER INSERT ON "identity" REFERENCING NEW TABLE AS new_table FOR EACH STATEMENT EXECUTE PROCEDURE process_audit(); +CREATE TRIGGER identity_update AFTER UPDATE ON "identity" REFERENCING NEW TABLE AS new_table OLD TABLE AS old_table FOR EACH STATEMENT EXECUTE PROCEDURE process_audit(); +CREATE TRIGGER identity_delete AFTER DELETE ON "identity" REFERENCING OLD TABLE AS old_table FOR EACH STATEMENT EXECUTE PROCEDURE process_audit(); +CREATE TRIGGER identity_updated_at BEFORE UPDATE ON "identity" FOR EACH ROW EXECUTE PROCEDURE set_updated_at(); +CREATE FUNCTION item__total__0614a79a__computed(r item) RETURNS NUMERIC AS $$ BEGIN + RETURN r."quantity" * r."price"; +END; $$ LANGUAGE plpgsql; +CREATE OR REPLACE FUNCTION item__exec_computed_fns() RETURNS TRIGGER AS $$ BEGIN + NEW.total := item__total__0614a79a__computed(NEW); + RETURN NEW; +END; $$ LANGUAGE plpgsql; +CREATE OR REPLACE TRIGGER item__computed_trigger BEFORE INSERT OR UPDATE ON "item" FOR EACH ROW EXECUTE PROCEDURE item__exec_computed_fns(); +=== + +[ + {"Model":"Identity","Field":"","Type":"ADDED"}, + {"Model":"Item","Field":"","Type":"ADDED"}, + {"Model":"KeelAudit","Field":"","Type":"ADDED"} +] diff --git a/migrations/testdata/computed_field_multiple_depend.txt b/migrations/testdata/computed_field_multiple_depend.txt new file mode 100644 index 000000000..de448f784 --- /dev/null +++ b/migrations/testdata/computed_field_multiple_depend.txt @@ -0,0 +1,42 @@ + +model Item { + fields { + price Decimal + quantity Number + totalWithShipping Decimal + total Decimal + } +} + +=== + +model Item { + fields { + price Decimal + quantity Number + totalWithShipping Decimal @computed(item.total + 5) + total Decimal @computed(item.quantity * item.price) + } +} + +=== + +CREATE FUNCTION item__total__0614a79a__computed(r item) RETURNS NUMERIC AS $$ BEGIN + RETURN r."quantity" * r."price"; +END; $$ LANGUAGE plpgsql; +CREATE FUNCTION item__total_with_shipping__53d0d09b__computed(r item) RETURNS NUMERIC AS $$ BEGIN + RETURN r."total" + 5; +END; $$ LANGUAGE plpgsql; +CREATE OR REPLACE FUNCTION item__exec_computed_fns() RETURNS TRIGGER AS $$ BEGIN + NEW.total := item__total__0614a79a__computed(NEW); + NEW.total_with_shipping := item__total_with_shipping__53d0d09b__computed(NEW); + RETURN NEW; +END; $$ LANGUAGE plpgsql; +CREATE OR REPLACE TRIGGER item__computed_trigger BEFORE INSERT OR UPDATE ON "item" FOR EACH ROW EXECUTE PROCEDURE item__exec_computed_fns(); + +=== + +[ + {"Model":"Item","Field":"total","Type":"MODIFIED"}, + {"Model":"Item","Field":"totalWithShipping","Type":"MODIFIED"} +] diff --git a/migrations/testdata/computed_field_removed_attr.txt b/migrations/testdata/computed_field_removed_attr.txt new file mode 100644 index 000000000..9cd5669a5 --- /dev/null +++ b/migrations/testdata/computed_field_removed_attr.txt @@ -0,0 +1,29 @@ +model Item { + fields { + price Decimal + quantity Number + total Decimal @computed(item.quantity * item.price) + } +} + +=== + +model Item { + fields { + price Decimal + quantity Number + total Decimal + } +} + +=== + +DROP FUNCTION item__total__0614a79a__computed; +DROP TRIGGER item__computed_trigger ON item; +DROP FUNCTION item__exec_computed_fns; + +=== + +[ + {"Model":"Item","Field":"total","Type":"MODIFIED"} +] diff --git a/migrations/testdata/computed_field_removed_field.txt b/migrations/testdata/computed_field_removed_field.txt new file mode 100644 index 000000000..5df57c56e --- /dev/null +++ b/migrations/testdata/computed_field_removed_field.txt @@ -0,0 +1,29 @@ +model Item { + fields { + price Decimal + quantity Number + total Decimal @computed(item.quantity * item.price) + } +} + +=== + +model Item { + fields { + price Decimal + quantity Number + } +} + +=== + +ALTER TABLE "item" DROP COLUMN "total"; +DROP FUNCTION item__total__0614a79a__computed; +DROP TRIGGER item__computed_trigger ON item; +DROP FUNCTION item__exec_computed_fns; + +=== + +[ + {"Model":"Item","Field":"total","Type":"REMOVED"} +] diff --git a/migrations/testdata/computed_field_renamed_field.txt b/migrations/testdata/computed_field_renamed_field.txt new file mode 100644 index 000000000..b16c52a4c --- /dev/null +++ b/migrations/testdata/computed_field_renamed_field.txt @@ -0,0 +1,38 @@ +model Item { + fields { + price Decimal + quantity Number + total Decimal @computed(item.quantity * item.price) + } +} + +=== + +model Item { + fields { + price Decimal + quantity Number + newTotal Decimal @computed(item.quantity * item.price) + } +} + +=== + +ALTER TABLE "item" ADD COLUMN "new_total" NUMERIC NOT NULL; +ALTER TABLE "item" DROP COLUMN "total"; +CREATE FUNCTION item__new_total__0614a79a__computed(r item) RETURNS NUMERIC AS $$ BEGIN + RETURN r."quantity" * r."price"; +END; $$ LANGUAGE plpgsql; +DROP FUNCTION item__total__0614a79a__computed; +CREATE OR REPLACE FUNCTION item__exec_computed_fns() RETURNS TRIGGER AS $$ BEGIN + NEW.new_total := item__new_total__0614a79a__computed(NEW); + RETURN NEW; +END; $$ LANGUAGE plpgsql; +CREATE OR REPLACE TRIGGER item__computed_trigger BEFORE INSERT OR UPDATE ON "item" FOR EACH ROW EXECUTE PROCEDURE item__exec_computed_fns(); + +=== + +[ + {"Model":"Item","Field":"newTotal","Type":"ADDED"}, + {"Model":"Item","Field":"total","Type":"REMOVED"} +] diff --git a/migrations/testdata/computed_field_unchanged.txt b/migrations/testdata/computed_field_unchanged.txt new file mode 100644 index 000000000..e7f97fc48 --- /dev/null +++ b/migrations/testdata/computed_field_unchanged.txt @@ -0,0 +1,24 @@ +model Item { + fields { + price Decimal + quantity Number + total Decimal @computed(item.quantity * item.price) + } +} + +=== + +model Item { + fields { + price Decimal + quantity Number + total Decimal @computed(item.quantity * item.price) + } +} + +=== + +=== + +[] + diff --git a/node/codegen.go b/node/codegen.go index 6c48b4f7a..7ff37f16f 100644 --- a/node/codegen.go +++ b/node/codegen.go @@ -342,7 +342,7 @@ func writeCreateValuesType(w *codegen.Writer, schema *proto.Schema, model *proto } w.Write(field.Name) - if field.Optional || field.DefaultValue != nil || field.IsHasMany() { + if field.Optional || field.DefaultValue != nil || field.IsHasMany() || field.ComputedExpression != nil { w.Write("?") } @@ -431,7 +431,7 @@ func writeFindManyParamsInterface(w *codegen.Writer, model *proto.Model) { switch f.Type.Type { // scalar types are only permitted to sort by - case proto.Type_TYPE_BOOL, proto.Type_TYPE_DATE, proto.Type_TYPE_DATETIME, proto.Type_TYPE_INT, proto.Type_TYPE_STRING, proto.Type_TYPE_ENUM, proto.Type_TYPE_TIMESTAMP, proto.Type_TYPE_ID: + case proto.Type_TYPE_BOOL, proto.Type_TYPE_DATE, proto.Type_TYPE_DATETIME, proto.Type_TYPE_INT, proto.Type_TYPE_STRING, proto.Type_TYPE_ENUM, proto.Type_TYPE_TIMESTAMP, proto.Type_TYPE_ID, proto.Type_TYPE_DECIMAL: return true default: // includes types such as password, secret, model etc @@ -678,7 +678,7 @@ func writeModelAPIDeclaration(w *codegen.Writer, model *proto.Model) { w.Indent() nonOptionalFields := lo.Filter(model.Fields, func(f *proto.Field, _ int) bool { - return !f.Optional && f.DefaultValue == nil + return !f.Optional && f.DefaultValue == nil && f.ComputedExpression == nil }) tsDocComment(w, func(w *codegen.Writer) { diff --git a/node/codegen_test.go b/node/codegen_test.go index f8499e135..01ad8e351 100644 --- a/node/codegen_test.go +++ b/node/codegen_test.go @@ -42,6 +42,7 @@ model Person { height Decimal bio Markdown file File + heightInMetres Decimal @computed(person.height * 0.3048) } }` @@ -59,6 +60,7 @@ export interface PersonTable { height: number bio: string file: FileDbRecord + heightInMetres: number id: Generated createdAt: Generated updatedAt: Generated @@ -106,6 +108,7 @@ export interface Person { height: number bio: string file: runtime.File + heightInMetres: number id: string createdAt: Date updatedAt: Date @@ -131,6 +134,7 @@ export type PersonCreateValues = { height: number bio: string file: runtime.InlineFile | runtime.File + heightInMetres?: number id?: string createdAt?: Date updatedAt?: Date @@ -180,6 +184,7 @@ export interface PersonWhereConditions { tags?: string[] | runtime.StringArrayWhereCondition; height?: number | runtime.NumberWhereCondition; bio?: string | runtime.StringWhereCondition; + heightInMetres?: number | runtime.NumberWhereCondition; id?: string | runtime.IDWhereCondition; createdAt?: Date | runtime.DateWhereCondition; updatedAt?: Date | runtime.DateWhereCondition; @@ -314,6 +319,8 @@ export type PersonOrderBy = { dateOfBirth?: runtime.SortDirection, gender?: runtime.SortDirection, hasChildren?: runtime.SortDirection, + height?: runtime.SortDirection, + heightInMetres?: runtime.SortDirection, id?: runtime.SortDirection, createdAt?: runtime.SortDirection, updatedAt?: runtime.SortDirection diff --git a/node/templates/client/core.ts b/node/templates/client/core.ts index 5e52f50a2..7b815213f 100644 --- a/node/templates/client/core.ts +++ b/node/templates/client/core.ts @@ -150,7 +150,9 @@ export class Core { /** * A promise that resolves when the session is refreshed. */ - refreshingPromise: undefined as Promise> | undefined, + refreshingPromise: undefined as + | Promise> + | undefined, /** * Returns data field set to the list of supported authentication providers and their SSO login URLs. @@ -324,10 +326,10 @@ export class Core { // If refreshing already, wait for the existing refreshing promisee if (!this.auth.refreshingPromise) { - this.auth.refreshingPromise = this.auth.requestToken({ - grant_type: "refresh_token", - refresh_token: refreshToken, - }); + this.auth.refreshingPromise = this.auth.requestToken({ + grant_type: "refresh_token", + refresh_token: refreshToken, + }); } const authResponse = await this.auth.refreshingPromise; diff --git a/proto/model.go b/proto/model.go index 40128e000..bf50547e5 100644 --- a/proto/model.go +++ b/proto/model.go @@ -46,3 +46,14 @@ func (m *Model) PrimaryKeyFieldName() string { } return "" } + +// GetComputedFields returns all the computed fields on the given model. +func (m *Model) GetComputedFields() []*Field { + fields := []*Field{} + for _, f := range m.Fields { + if f.ComputedExpression != nil { + fields = append(fields, f) + } + } + return fields +} diff --git a/runtime/actions/generate_computed.go b/runtime/actions/generate_computed.go new file mode 100644 index 000000000..793bd57cd --- /dev/null +++ b/runtime/actions/generate_computed.go @@ -0,0 +1,119 @@ +package actions + +import ( + "errors" + "fmt" + "regexp" + "strings" + + "github.com/google/cel-go/common/operators" + "github.com/iancoleman/strcase" + "github.com/teamkeel/keel/expressions/resolve" + "github.com/teamkeel/keel/proto" + + "github.com/teamkeel/keel/schema/parser" +) + +// GenerateComputedFunction visits the expression and generates a SQL expression +func GenerateComputedFunction(schema *proto.Schema, model *proto.Model, field *proto.Field) resolve.Visitor[string] { + return &computedQueryGen{ + schema: schema, + model: model, + field: field, + sql: "", + } +} + +var _ resolve.Visitor[string] = new(computedQueryGen) + +type computedQueryGen struct { + schema *proto.Schema + model *proto.Model + field *proto.Field + sql string +} + +func (v *computedQueryGen) StartCondition(nested bool) error { + if nested { + v.sql += "(" + } + return nil +} + +func (v *computedQueryGen) EndCondition(nested bool) error { + if nested { + v.sql += ")" + } + return nil +} + +func (v *computedQueryGen) VisitAnd() error { + v.sql += " AND " + return nil +} + +func (v *computedQueryGen) VisitOr() error { + v.sql += " OR " + return nil +} + +func (v *computedQueryGen) VisitNot() error { + v.sql += " NOT " + return nil +} + +func (v *computedQueryGen) VisitOperator(op string) error { + // Map CEL operators to SQL operators + sqlOp := map[string]string{ + operators.Add: "+", + operators.Subtract: "-", + operators.Multiply: "*", + operators.Divide: "/", + operators.Equals: "IS NOT DISTINCT FROM", + operators.NotEquals: "IS DISTINCT FROM", + operators.Greater: ">", + operators.GreaterEquals: ">=", + operators.Less: "<", + operators.LessEquals: "<=", + }[op] + + if sqlOp == "" { + return fmt.Errorf("unsupported operator: %s", op) + } + + v.sql += " " + sqlOp + " " + return nil +} + +func (v *computedQueryGen) VisitLiteral(value any) error { + switch val := value.(type) { + case int64: + v.sql += fmt.Sprintf("%v", val) + case float64: + v.sql += fmt.Sprintf("%v", val) + case string: + v.sql += fmt.Sprintf("\"%v\"", val) + case bool: + v.sql += fmt.Sprintf("%t", val) + case nil: + v.sql += "NULL" + default: + return fmt.Errorf("unsupported literal type: %T", value) + } + return nil +} + +func (v *computedQueryGen) VisitIdent(ident *parser.ExpressionIdent) error { + v.sql += "r." + sqlQuote(strcase.ToSnake(ident.Fragments[len(ident.Fragments)-1])) + return nil +} + +func (v *computedQueryGen) VisitIdentArray(idents []*parser.ExpressionIdent) error { + return errors.New("ident arrays not supported in computed expressions") +} + +func (v *computedQueryGen) Result() (string, error) { + // Remove multiple whitespaces and trim + re := regexp.MustCompile(`\s+`) + return re.ReplaceAllString(strings.TrimSpace(v.sql), " "), nil +} diff --git a/runtime/actions/generate_computed_test.go b/runtime/actions/generate_computed_test.go new file mode 100644 index 000000000..1f2405c2b --- /dev/null +++ b/runtime/actions/generate_computed_test.go @@ -0,0 +1,171 @@ +package actions_test + +import ( + "strings" + "testing" + + "github.com/teamkeel/keel/expressions/resolve" + "github.com/teamkeel/keel/proto" + "github.com/teamkeel/keel/runtime/actions" + "github.com/teamkeel/keel/schema" + "github.com/teamkeel/keel/schema/parser" + "github.com/teamkeel/keel/schema/reader" + "github.com/test-go/testify/assert" +) + +const testSchema = ` +model Item { + fields { + product Text + price Decimal? + quantity Number + isActive Boolean + #placeholder# + } +}` + +type computedTestCase struct { + // Name given to the test case + name string + // Valid keel schema for this test case + keelSchema string + // action name to run test upon + field string + // Input map for action + expectedSql string +} + +var computedTestCases = []computedTestCase{ + + { + name: "adding field with literal", + keelSchema: testSchema, + field: "total Decimal @computed(item.price + 100)", + expectedSql: `r."price" + 100`, + }, + { + name: "subtracting field with literal", + keelSchema: testSchema, + field: "total Decimal @computed(item.price - 100)", + expectedSql: `r."price" - 100`, + }, + { + name: "dividing field with literal", + keelSchema: testSchema, + field: "total Decimal @computed(item.price / 100)", + expectedSql: `r."price" / 100`, + }, + { + name: "multiplying field with literal", + keelSchema: testSchema, + field: "total Decimal @computed(item.price * 100)", + expectedSql: `r."price" * 100`, + }, + { + name: "multiply fields on same model", + keelSchema: testSchema, + field: "total Decimal @computed(item.price * item.quantity)", + expectedSql: `r."price" * r."quantity"`, + }, + { + name: "parenthesis", + keelSchema: testSchema, + field: "total Decimal @computed(item.quantity * (1 + item.quantity) / (100 * (item.price + 1)))", + expectedSql: `r."quantity" * (1 + r."quantity") / (100 * (r."price" + 1))`, + }, + { + name: "no parenthesis", + keelSchema: testSchema, + field: "total Decimal @computed(item.quantity * 1 + item.quantity / 100 * item.price + 1)", + expectedSql: `r."quantity" * 1 + r."quantity" / 100 * r."price" + 1`, + }, + { + name: "bool greater than", + keelSchema: testSchema, + field: "isExpensive Boolean @computed(item.price > 100)", + expectedSql: `r."price" > 100`, + }, + { + name: "bool greater or equals", + keelSchema: testSchema, + field: "isExpensive Boolean @computed(item.price >= 100)", + expectedSql: `r."price" >= 100`, + }, + { + name: "bool less than", + keelSchema: testSchema, + field: "isCheap Boolean @computed(item.price < 100)", + expectedSql: `r."price" < 100`, + }, + { + name: "bool less or equals", + keelSchema: testSchema, + field: "isCheap Boolean @computed(item.price <= 100)", + expectedSql: `r."price" <= 100`, + }, + { + name: "bool is not null", + keelSchema: testSchema, + field: "hasPrice Boolean @computed(item.price != null)", + expectedSql: `r."price" IS DISTINCT FROM NULL`, + }, + { + name: "bool is null", + keelSchema: testSchema, + field: "noPrice Boolean @computed(item.price == null)", + expectedSql: `r."price" IS NOT DISTINCT FROM NULL`, + }, + { + name: "bool with and", + keelSchema: testSchema, + field: "isExpensive Boolean @computed(item.price > 100 && item.isActive)", + expectedSql: `r."price" > 100 AND r."is_active"`, + }, + { + name: "bool with or", + keelSchema: testSchema, + field: "isExpensive Boolean @computed(item.price > 100 || item.isActive)", + expectedSql: `(r."price" > 100 OR r."is_active")`, + }, + { + name: "negation", + keelSchema: testSchema, + field: "isExpensive Boolean @computed(item.price > 100 || !item.isActive)", + expectedSql: `(r."price" > 100 OR NOT r."is_active")`, + }, +} + +func TestGeneratedComputed(t *testing.T) { + t.Parallel() + for _, testCase := range computedTestCases { + t.Run(testCase.name, func(t *testing.T) { + raw := strings.Replace(testCase.keelSchema, "#placeholder#", testCase.field, 1) + + schemaFiles := + &reader.Inputs{ + SchemaFiles: []*reader.SchemaFile{ + { + Contents: raw, + FileName: "schema.keel", + }, + }, + } + + builder := &schema.Builder{} + schema, err := builder.MakeFromInputs(schemaFiles) + assert.NoError(t, err) + + model := schema.Models[0] + fieldName := strings.Split(testCase.field, " ")[0] + field := proto.FindField(schema.Models, model.Name, fieldName) + + expression, err := parser.ParseExpression(field.ComputedExpression.Source) + assert.NoError(t, err) + + sql, err := resolve.RunCelVisitor(expression, actions.GenerateComputedFunction(schema, model, field)) + assert.NoError(t, err) + + assert.Equal(t, testCase.expectedSql, sql, "expected `%s` but got `%s`", testCase.expectedSql, sql) + }) + } +} diff --git a/runtime/actions/generate_filter.go b/runtime/actions/generate_filter.go index ee900321d..7f2b8b52a 100644 --- a/runtime/actions/generate_filter.go +++ b/runtime/actions/generate_filter.go @@ -52,6 +52,7 @@ func (v *whereQueryGen) StartCondition(nested bool) error { return nil } + func (v *whereQueryGen) EndCondition(nested bool) error { if _, ok := v.operators.Peek(); ok && v.operands.Size() == 2 { operator, _ := v.operators.Pop() diff --git a/runtime/actions/query.go b/runtime/actions/query.go index 74c9e93ff..57869c9bf 100644 --- a/runtime/actions/query.go +++ b/runtime/actions/query.go @@ -1474,7 +1474,7 @@ func (query *QueryBuilder) generateConditionTemplate(lhs *QueryOperand, operator case AllGreaterThanEquals, AllOnOrAfter: template = fmt.Sprintf("%s <= ALL(%s)", rhsSqlOperand, lhsSqlOperand) - /* All relative date operators */ + /* Relative date operators */ case BeforeRelative: template = fmt.Sprintf("%s < %s", lhsSqlOperand, rhsSqlOperand) case AfterRelative: @@ -1498,6 +1498,7 @@ func (query *QueryBuilder) generateConditionTemplate(lhs *QueryOperand, operator } template = fmt.Sprintf("%s >= %s AND %s < %s", lhsSqlOperand, rhsSqlOperand, lhsSqlOperand, end) + default: return "", nil, fmt.Errorf("operator: %v is not yet supported", operator) } diff --git a/runtime/actions/query_test.go b/runtime/actions/query_test.go index afb2547d7..b5bacb377 100644 --- a/runtime/actions/query_test.go +++ b/runtime/actions/query_test.go @@ -3778,7 +3778,7 @@ var testCases = []testCase{ "thing" WHERE "thing"."id" IS NOT DISTINCT FROM ? AND - NOT (("thing"."is_active" IS NOT DISTINCT FROM ? OR "thing"."number" IS DISTINCT FROM ?))`, + NOT ("thing"."is_active" IS NOT DISTINCT FROM ? OR "thing"."number" IS DISTINCT FROM ?)`, expectedArgs: []any{"123", true, int64(0)}, }, } @@ -3787,6 +3787,7 @@ func TestQueryBuilder(t *testing.T) { t.Parallel() for _, testCase := range testCases { testCase := testCase + t.Run(testCase.name, func(t *testing.T) { t.Parallel() ctx := context.Background() diff --git a/runtime/actions/update.go b/runtime/actions/update.go index b798aa923..23d04e4e6 100644 --- a/runtime/actions/update.go +++ b/runtime/actions/update.go @@ -64,9 +64,6 @@ func Update(scope *Scope, input map[string]any) (res map[string]any, err error) return nil, common.NewPermissionError() } } - if err != nil { - return nil, err - } // Execute database request, expecting a single result res, err = statement.ExecuteToSingle(scope.Context) diff --git a/schema/attributes/default_test.go b/schema/attributes/default_test.go index 4c4513dec..39810617c 100644 --- a/schema/attributes/default_test.go +++ b/schema/attributes/default_test.go @@ -97,7 +97,7 @@ func TestDefault_ValidNumber(t *testing.T) { require.Empty(t, issues) } -func TestDefault_InvalidNumber(t *testing.T) { +func TestDefault_ValidNumberFromDecimal(t *testing.T) { schema := parse(t, &reader.SchemaFile{FileName: "test.keel", Contents: ` model Person { fields { @@ -111,8 +111,24 @@ func TestDefault_InvalidNumber(t *testing.T) { issues, err := attributes.ValidateDefaultExpression(schema, field, expression) require.NoError(t, err) - require.Len(t, issues, 1) - require.Equal(t, "expression expected to resolve to type Number but it is Decimal", issues[0].Message) + require.Len(t, issues, 0) +} + +func TestDefault_ValidDecimalFromNumber(t *testing.T) { + schema := parse(t, &reader.SchemaFile{FileName: "test.keel", Contents: ` + model Person { + fields { + age Decimal @default(1) + } + }`}) + + model := query.Model(schema, "Person") + field := query.Field(model, "age") + expression := field.Attributes[0].Arguments[0].Expression + + issues, err := attributes.ValidateDefaultExpression(schema, field, expression) + require.NoError(t, err) + require.Len(t, issues, 0) } func TestDefault_ValidID(t *testing.T) { diff --git a/schema/completions/completions.go b/schema/completions/completions.go index c2f56482b..de1ec44b4 100644 --- a/schema/completions/completions.go +++ b/schema/completions/completions.go @@ -254,6 +254,7 @@ func getBlockCompletions(asts []*parser.AST, tokenAtPos *TokensAtPosition, keywo parser.AttributeUnique, parser.AttributeDefault, parser.AttributeRelation, + parser.AttributeComputed, }) } @@ -323,6 +324,7 @@ func getBlockCompletions(asts []*parser.AST, tokenAtPos *TokensAtPosition, keywo parser.AttributeUnique, parser.AttributeDefault, parser.AttributeRelation, + parser.AttributeComputed, }) } @@ -644,7 +646,7 @@ func getAttributeArgCompletions(asts []*parser.AST, t *TokensAtPosition, cfg *co enclosingBlock := getTypeOfEnclosingBlock(t) switch attrName { - case parser.AttributeSet, parser.AttributeWhere, parser.AttributeValidate: + case parser.AttributeSet, parser.AttributeWhere, parser.AttributeValidate, parser.AttributeComputed: return getExpressionCompletions(asts, t, cfg) case parser.AttributePermission: return getPermissionArgCompletions(asts, t, cfg) diff --git a/schema/completions/completions_test.go b/schema/completions/completions_test.go index 3a7e64e64..2d42b33be 100644 --- a/schema/completions/completions_test.go +++ b/schema/completions/completions_test.go @@ -434,7 +434,7 @@ func TestFieldCompletions(t *testing.T) { } } }`, - expected: []string{"@unique", "@default", "@relation"}, + expected: []string{"@unique", "@default", "@relation", "@computed"}, }, { name: "field-attributes-bare-at", @@ -443,7 +443,7 @@ func TestFieldCompletions(t *testing.T) { name Text @ } }`, - expected: []string{"@unique", "@default", "@relation"}, + expected: []string{"@unique", "@default", "@relation", "@computed"}, }, { name: "field-attributes-whitespace", @@ -453,7 +453,7 @@ func TestFieldCompletions(t *testing.T) { name Text } }`, - expected: []string{"@unique", "@default", "@relation"}, + expected: []string{"@unique", "@default", "@relation", "@computed"}, }, } @@ -1089,6 +1089,37 @@ func TestSetAttributeCompletions(t *testing.T) { runTestsCases(t, cases) } +func TestComputedAttributeCompletions(t *testing.T) { + cases := []testCase{ + { + name: "computed-attribute-operands", + schema: ` + model Item { + fields { + price Decimal + quantity Decimal + total Decimal @computed() + } + }`, + expected: []string{"ctx", "item"}, + }, + { + name: "computed-attribute-model-fields", + schema: ` + model Item { + fields { + price Decimal + quantity Decimal + total Decimal @computed(item.) + } + }`, + expected: []string{"createdAt", "id", "price", "quantity", "total", "updatedAt"}, + }, + } + + runTestsCases(t, cases) +} + func TestFunctionCompletions(t *testing.T) { cases := []testCase{ // name tests diff --git a/schema/parser/expressions.go b/schema/parser/expressions.go index 264ae2b47..7fab3cfdb 100644 --- a/schema/parser/expressions.go +++ b/schema/parser/expressions.go @@ -49,27 +49,6 @@ func (e *Expression) Parse(lex *lexer.PeekingLexer) error { } } -// func (e *Expression) String() string { -// if len(e.Tokens) == 0 { -// return "" -// } - -// var result strings.Builder -// currentColumn := e.Pos.Column - -// // Handle tokens without preserving line breaks -// for _, token := range e.Tokens { -// // Add spaces to reach the correct column position -// if token.Pos.Column > currentColumn { -// result.WriteString(strings.Repeat(" ", token.Pos.Column-currentColumn)) -// } -// result.WriteString(token.Value) -// currentColumn = token.Pos.Column + len(token.Value) -// } - -// return result.String() -// } - func (e *Expression) String() string { if len(e.Tokens) == 0 { return "" diff --git a/schema/query/query.go b/schema/query/query.go index 9542ad634..676a47079 100644 --- a/schema/query/query.go +++ b/schema/query/query.go @@ -350,6 +350,10 @@ func FieldIsUnique(field *parser.FieldNode) bool { return FieldHasAttribute(field, parser.AttributePrimaryKey) || FieldHasAttribute(field, parser.AttributeUnique) } +func FieldIsComputed(field *parser.FieldNode) bool { + return FieldHasAttribute(field, parser.AttributeComputed) +} + // CompositeUniqueFields returns the model's fields that make up a composite unique attribute func CompositeUniqueFields(model *parser.ModelNode, attribute *parser.AttributeNode) []*parser.FieldNode { if attribute.Name.Value != parser.AttributeUnique { diff --git a/schema/testdata/errors/attribute_computed.keel b/schema/testdata/errors/attribute_computed.keel index 8c6e617bf..9c9282ae8 100644 --- a/schema/testdata/errors/attribute_computed.keel +++ b/schema/testdata/errors/attribute_computed.keel @@ -2,6 +2,9 @@ model Item { fields { //expect-error:23:32:AttributeArgumentError:0 argument(s) provided to @computed but expected 1 total Decimal @computed + //expect-error:26:47:AttributeNotAllowedError:@computed cannot be used on repeated fields + //expect-error:36:46:AttributeExpressionError:expression expected to resolve to type Decimal[] but it is Decimal + totals Decimal[] @computed(item.total) //expect-error:19:36:AttributeNotAllowedError:@computed cannot be used on field of type File file File @computed("file") //expect-error:23:42:AttributeNotAllowedError:@computed cannot be used on field of type Vector diff --git a/schema/testdata/errors/attribute_computed_expression.keel b/schema/testdata/errors/attribute_computed_expression.keel index a09b7c3d1..c4114046c 100644 --- a/schema/testdata/errors/attribute_computed_expression.keel +++ b/schema/testdata/errors/attribute_computed_expression.keel @@ -21,7 +21,8 @@ model Item { ctx Boolean @computed(ctx.isAuthenticated) //expect-error:27:50:AttributeNotAllowedError:@computed cannot be used on field of type Identity identity Identity @computed(ctx.identity) - + //expect-error:33:43:AttributeArgumentError:@computed expressions cannot reference itself + total Decimal @computed(item.total * 5) } actions { get getItem(id) diff --git a/schema/testdata/proto/array_fields/proto.json b/schema/testdata/proto/array_fields/proto.json index 49c708cbf..0a06f4af3 100644 --- a/schema/testdata/proto/array_fields/proto.json +++ b/schema/testdata/proto/array_fields/proto.json @@ -103,9 +103,7 @@ "type": "TYPE_STRING" }, "optional": true, - "uniqueWith": [ - "issuer" - ] + "uniqueWith": ["issuer"] }, { "modelName": "Identity", @@ -142,9 +140,7 @@ "type": "TYPE_STRING" }, "optional": true, - "uniqueWith": [ - "email" - ] + "uniqueWith": ["email"] }, { "modelName": "Identity", @@ -376,9 +372,7 @@ "fieldName": "texts", "repeated": true }, - "target": [ - "texts" - ] + "target": ["texts"] }, { "messageName": "CreateThingInput", @@ -389,9 +383,7 @@ "fieldName": "numbers", "repeated": true }, - "target": [ - "numbers" - ] + "target": ["numbers"] }, { "messageName": "CreateThingInput", @@ -402,9 +394,7 @@ "fieldName": "booleans", "repeated": true }, - "target": [ - "booleans" - ] + "target": ["booleans"] }, { "messageName": "CreateThingInput", @@ -415,9 +405,7 @@ "fieldName": "dates", "repeated": true }, - "target": [ - "dates" - ] + "target": ["dates"] }, { "messageName": "CreateThingInput", @@ -428,9 +416,7 @@ "fieldName": "timestamps", "repeated": true }, - "target": [ - "timestamps" - ] + "target": ["timestamps"] } ] }, @@ -1093,9 +1079,7 @@ "type": "TYPE_MESSAGE", "messageName": "StringArrayQueryInput" }, - "target": [ - "texts" - ] + "target": ["texts"] }, { "messageName": "ListThingsWhere", @@ -1104,9 +1088,7 @@ "type": "TYPE_MESSAGE", "messageName": "IntArrayQueryInput" }, - "target": [ - "numbers" - ] + "target": ["numbers"] }, { "messageName": "ListThingsWhere", @@ -1115,9 +1097,7 @@ "type": "TYPE_MESSAGE", "messageName": "BooleanArrayQueryInput" }, - "target": [ - "booleans" - ] + "target": ["booleans"] }, { "messageName": "ListThingsWhere", @@ -1126,9 +1106,7 @@ "type": "TYPE_MESSAGE", "messageName": "DateArrayQueryInput" }, - "target": [ - "dates" - ] + "target": ["dates"] }, { "messageName": "ListThingsWhere", @@ -1137,9 +1115,7 @@ "type": "TYPE_MESSAGE", "messageName": "TimestampArrayQueryInput" }, - "target": [ - "timestamps" - ] + "target": ["timestamps"] } ] }, @@ -1189,4 +1165,4 @@ ] } ] -} \ No newline at end of file +} diff --git a/schema/testdata/proto/attribute_computed/proto.json b/schema/testdata/proto/attribute_computed/proto.json new file mode 100644 index 000000000..980149a3c --- /dev/null +++ b/schema/testdata/proto/attribute_computed/proto.json @@ -0,0 +1,363 @@ +{ + "models": [ + { + "name": "Item", + "fields": [ + { + "modelName": "Item", + "name": "price", + "type": { + "type": "TYPE_DECIMAL" + } + }, + { + "modelName": "Item", + "name": "units", + "type": { + "type": "TYPE_DECIMAL" + } + }, + { + "modelName": "Item", + "name": "total", + "type": { + "type": "TYPE_DECIMAL" + }, + "computedExpression": { + "source": "item.price * item.units" + } + }, + { + "modelName": "Item", + "name": "id", + "type": { + "type": "TYPE_ID" + }, + "unique": true, + "primaryKey": true, + "defaultValue": { + "useZeroValue": true + } + }, + { + "modelName": "Item", + "name": "createdAt", + "type": { + "type": "TYPE_DATETIME" + }, + "defaultValue": { + "useZeroValue": true + } + }, + { + "modelName": "Item", + "name": "updatedAt", + "type": { + "type": "TYPE_DATETIME" + }, + "defaultValue": { + "useZeroValue": true + } + } + ], + "actions": [ + { + "modelName": "Item", + "name": "createItem", + "type": "ACTION_TYPE_CREATE", + "implementation": "ACTION_IMPLEMENTATION_AUTO", + "inputMessageName": "CreateItemInput" + } + ] + }, + { + "name": "Identity", + "fields": [ + { + "modelName": "Identity", + "name": "email", + "type": { + "type": "TYPE_STRING" + }, + "optional": true, + "uniqueWith": ["issuer"] + }, + { + "modelName": "Identity", + "name": "emailVerified", + "type": { + "type": "TYPE_BOOL" + }, + "defaultValue": { + "expression": { + "source": "false" + } + } + }, + { + "modelName": "Identity", + "name": "password", + "type": { + "type": "TYPE_PASSWORD" + }, + "optional": true + }, + { + "modelName": "Identity", + "name": "externalId", + "type": { + "type": "TYPE_STRING" + }, + "optional": true + }, + { + "modelName": "Identity", + "name": "issuer", + "type": { + "type": "TYPE_STRING" + }, + "optional": true, + "uniqueWith": ["email"] + }, + { + "modelName": "Identity", + "name": "name", + "type": { + "type": "TYPE_STRING" + }, + "optional": true + }, + { + "modelName": "Identity", + "name": "givenName", + "type": { + "type": "TYPE_STRING" + }, + "optional": true + }, + { + "modelName": "Identity", + "name": "familyName", + "type": { + "type": "TYPE_STRING" + }, + "optional": true + }, + { + "modelName": "Identity", + "name": "middleName", + "type": { + "type": "TYPE_STRING" + }, + "optional": true + }, + { + "modelName": "Identity", + "name": "nickName", + "type": { + "type": "TYPE_STRING" + }, + "optional": true + }, + { + "modelName": "Identity", + "name": "profile", + "type": { + "type": "TYPE_STRING" + }, + "optional": true + }, + { + "modelName": "Identity", + "name": "picture", + "type": { + "type": "TYPE_STRING" + }, + "optional": true + }, + { + "modelName": "Identity", + "name": "website", + "type": { + "type": "TYPE_STRING" + }, + "optional": true + }, + { + "modelName": "Identity", + "name": "gender", + "type": { + "type": "TYPE_STRING" + }, + "optional": true + }, + { + "modelName": "Identity", + "name": "zoneInfo", + "type": { + "type": "TYPE_STRING" + }, + "optional": true + }, + { + "modelName": "Identity", + "name": "locale", + "type": { + "type": "TYPE_STRING" + }, + "optional": true + }, + { + "modelName": "Identity", + "name": "id", + "type": { + "type": "TYPE_ID" + }, + "unique": true, + "primaryKey": true, + "defaultValue": { + "useZeroValue": true + } + }, + { + "modelName": "Identity", + "name": "createdAt", + "type": { + "type": "TYPE_DATETIME" + }, + "defaultValue": { + "useZeroValue": true + } + }, + { + "modelName": "Identity", + "name": "updatedAt", + "type": { + "type": "TYPE_DATETIME" + }, + "defaultValue": { + "useZeroValue": true + } + } + ], + "actions": [ + { + "modelName": "Identity", + "name": "requestPasswordReset", + "type": "ACTION_TYPE_WRITE", + "implementation": "ACTION_IMPLEMENTATION_RUNTIME", + "inputMessageName": "RequestPasswordResetInput", + "responseMessageName": "RequestPasswordResetResponse" + }, + { + "modelName": "Identity", + "name": "resetPassword", + "type": "ACTION_TYPE_WRITE", + "implementation": "ACTION_IMPLEMENTATION_RUNTIME", + "inputMessageName": "ResetPasswordInput", + "responseMessageName": "ResetPasswordResponse" + } + ] + } + ], + "apis": [ + { + "name": "Api", + "apiModels": [ + { + "modelName": "Item", + "modelActions": [ + { + "actionName": "createItem" + } + ] + }, + { + "modelName": "Identity", + "modelActions": [ + { + "actionName": "requestPasswordReset" + }, + { + "actionName": "resetPassword" + } + ] + } + ] + } + ], + "messages": [ + { + "name": "Any" + }, + { + "name": "RequestPasswordResetInput", + "fields": [ + { + "messageName": "RequestPasswordResetInput", + "name": "email", + "type": { + "type": "TYPE_STRING" + } + }, + { + "messageName": "RequestPasswordResetInput", + "name": "redirectUrl", + "type": { + "type": "TYPE_STRING" + } + } + ] + }, + { + "name": "RequestPasswordResetResponse" + }, + { + "name": "ResetPasswordInput", + "fields": [ + { + "messageName": "ResetPasswordInput", + "name": "token", + "type": { + "type": "TYPE_STRING" + } + }, + { + "messageName": "ResetPasswordInput", + "name": "password", + "type": { + "type": "TYPE_STRING" + } + } + ] + }, + { + "name": "ResetPasswordResponse" + }, + { + "name": "CreateItemInput", + "fields": [ + { + "messageName": "CreateItemInput", + "name": "price", + "type": { + "type": "TYPE_DECIMAL", + "modelName": "Item", + "fieldName": "price" + }, + "target": ["price"] + }, + { + "messageName": "CreateItemInput", + "name": "units", + "type": { + "type": "TYPE_DECIMAL", + "modelName": "Item", + "fieldName": "units" + }, + "target": ["units"] + } + ] + } + ] +} diff --git a/schema/testdata/proto/attribute_computed/schema.keel b/schema/testdata/proto/attribute_computed/schema.keel new file mode 100644 index 000000000..15ea81375 --- /dev/null +++ b/schema/testdata/proto/attribute_computed/schema.keel @@ -0,0 +1,10 @@ +model Item { + fields { + price Decimal + units Decimal + total Decimal @computed(item.price * item.units) + } + actions { + create createItem() with (price, units) + } +} \ No newline at end of file diff --git a/schema/testdata/proto/computed_fields/proto.json b/schema/testdata/proto/computed_fields/proto.json index ebab0a8b6..6046dd4b5 100644 --- a/schema/testdata/proto/computed_fields/proto.json +++ b/schema/testdata/proto/computed_fields/proto.json @@ -146,9 +146,7 @@ "type": "TYPE_STRING" }, "optional": true, - "uniqueWith": [ - "issuer" - ] + "uniqueWith": ["issuer"] }, { "modelName": "Identity", @@ -185,9 +183,7 @@ "type": "TYPE_STRING" }, "optional": true, - "uniqueWith": [ - "email" - ] + "uniqueWith": ["email"] }, { "modelName": "Identity", @@ -403,4 +399,4 @@ "name": "ResetPasswordResponse" } ] -} \ No newline at end of file +} diff --git a/schema/validation/computed_attribute.go b/schema/validation/computed_attribute.go index 6d6101ded..304ca3cb6 100644 --- a/schema/validation/computed_attribute.go +++ b/schema/validation/computed_attribute.go @@ -3,8 +3,10 @@ package validation import ( "fmt" + "github.com/teamkeel/keel/expressions/resolve" "github.com/teamkeel/keel/schema/attributes" "github.com/teamkeel/keel/schema/parser" + "github.com/teamkeel/keel/schema/query" "github.com/teamkeel/keel/schema/validation/errorhandling" ) @@ -32,13 +34,9 @@ func ComputedAttributeRules(asts []*parser.AST, errs *errorhandling.ValidationEr } switch field.Type.Value { - case parser.FieldTypeID, - parser.FieldTypeText, - parser.FieldTypeBoolean, + case parser.FieldTypeBoolean, parser.FieldTypeNumber, - parser.FieldTypeDecimal, - parser.FieldTypeDate, - parser.FieldTypeTimestamp: + parser.FieldTypeDecimal: attribute = attr default: errs.AppendError( @@ -52,6 +50,18 @@ func ComputedAttributeRules(asts []*parser.AST, errs *errorhandling.ValidationEr ) } + if field.Repeated { + errs.AppendError( + errorhandling.NewValidationErrorWithDetails( + errorhandling.AttributeNotAllowedError, + errorhandling.ErrorDetails{ + Message: "@computed cannot be used on repeated fields", + }, + attr, + ), + ) + } + if len(attr.Arguments) != 1 { errs.AppendError( errorhandling.NewValidationErrorWithDetails( @@ -80,6 +90,7 @@ func ComputedAttributeRules(asts []*parser.AST, errs *errorhandling.ValidationEr Message: "expression could not be parsed", }, expression)) + return } if len(issues) > 0 { @@ -87,6 +98,31 @@ func ComputedAttributeRules(asts []*parser.AST, errs *errorhandling.ValidationEr errs.AppendError(issue) } } + + operands, err := resolve.IdentOperands(expression) + if err != nil { + return + } + + for _, operand := range operands { + if len(operand.Fragments) < 2 { + continue + } + + f := query.Field(model, operand.Fragments[1]) + + if f == field { + errs.AppendError( + errorhandling.NewValidationErrorWithDetails( + errorhandling.AttributeArgumentError, + errorhandling.ErrorDetails{ + Message: "@computed expressions cannot reference itself", + }, + operand, + ), + ) + } + } }, } } diff --git a/schema/validation/rules/actions/create_required.go b/schema/validation/rules/actions/create_required.go index 1ddb3eabe..f2f47379e 100644 --- a/schema/validation/rules/actions/create_required.go +++ b/schema/validation/rules/actions/create_required.go @@ -65,13 +65,15 @@ func checkField( // - relationship repeated fields // - fields which have a default // - built-in fields like CreatedAt, Id etc. +// - computed fields func isNotNeeded(asts []*parser.AST, model *parser.ModelNode, f *parser.FieldNode) bool { switch { case f.Optional, (f.Repeated && !f.IsScalar()), query.FieldHasAttribute(f, parser.AttributeDefault), query.IsBelongsToModelField(asts, model, f), - f.BuiltIn: + f.BuiltIn, + query.FieldIsComputed(f): return true default: return false