diff --git a/ndc-sqlgen/src/main/kotlin/io/hasura/ndc/sqlgen/BaseGenerator.kt b/ndc-sqlgen/src/main/kotlin/io/hasura/ndc/sqlgen/BaseGenerator.kt index b4d62ff..d2b8332 100644 --- a/ndc-sqlgen/src/main/kotlin/io/hasura/ndc/sqlgen/BaseGenerator.kt +++ b/ndc-sqlgen/src/main/kotlin/io/hasura/ndc/sqlgen/BaseGenerator.kt @@ -28,18 +28,26 @@ sealed interface BaseGenerator { fun buildComparison( col: Field, operator: ApplyBinaryComparisonOperator, - value: Field + listVal: List> ): Condition { + if (operator != ApplyBinaryComparisonOperator.IN && listVal.size != 1) { + error("Only the IN operator supports multiple values") + } + + // unwrap single value for use in all but the IN operator + // OR return falseCondition if listVal is empty + val singleVal = listVal.firstOrNull() ?: return DSL.falseCondition() + return when (operator) { - ApplyBinaryComparisonOperator.EQ -> col.eq(value) - ApplyBinaryComparisonOperator.GT -> col.gt(value) - ApplyBinaryComparisonOperator.GTE -> col.ge(value) - ApplyBinaryComparisonOperator.LT -> col.lt(value) - ApplyBinaryComparisonOperator.LTE -> col.le(value) - ApplyBinaryComparisonOperator.IN -> DSL.nullCondition() + ApplyBinaryComparisonOperator.EQ -> col.eq(singleVal) + ApplyBinaryComparisonOperator.GT -> col.gt(singleVal) + ApplyBinaryComparisonOperator.GTE -> col.ge(singleVal) + ApplyBinaryComparisonOperator.LT -> col.lt(singleVal) + ApplyBinaryComparisonOperator.LTE -> col.le(singleVal) + ApplyBinaryComparisonOperator.IN -> col.`in`(listVal) ApplyBinaryComparisonOperator.IS_NULL -> col.isNull - ApplyBinaryComparisonOperator.LIKE -> col.like(value as Field) - ApplyBinaryComparisonOperator.CONTAINS -> col.contains(value as Field) + ApplyBinaryComparisonOperator.LIKE -> col.like(singleVal as Field) + ApplyBinaryComparisonOperator.CONTAINS -> col.contains(singleVal as Field) } } @@ -124,15 +132,16 @@ sealed interface BaseGenerator { val comparisonValue = when (val v = e.value) { is ComparisonValue.ColumnComp -> { val col = splitCollectionName(getCollectionForCompCol(v.column, request)) - DSL.field(DSL.name(col + v.column.name)) + listOf(DSL.field(DSL.name(col + v.column.name))) } is ComparisonValue.ScalarComp -> - if(e.operator == ApplyBinaryComparisonOperator.IN) - return handleInComp(column, v) - else DSL.inline(v.value) + when (val scalarValue = v.value) { + is List<*> -> (scalarValue as List).map { DSL.inline(it) } + else -> listOf(DSL.inline(scalarValue)) + } - is ComparisonValue.VariableComp -> DSL.field(DSL.name(listOf("vars", v.name))) + is ComparisonValue.VariableComp -> listOf(DSL.field(DSL.name(listOf("vars", v.name)))) } return buildComparison(column, e.operator, comparisonValue) } @@ -212,17 +221,6 @@ sealed interface BaseGenerator { } } - fun handleInComp(column: Field, value: ComparisonValue.ScalarComp): Condition { - return when (val scalarValue = value.value) { - is List<*> -> { - if (scalarValue.isEmpty()) DSL.falseCondition() - else column.`in`(scalarValue.map { DSL.inline(it) }) - } - // Handle non-array scalar value - else -> column.eq(DSL.inline(scalarValue)) - } - } - fun splitCollectionName(collectionName: String): List { return collectionName.split(".") }