diff --git a/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/source/LakeSoulLookupTableSource.java b/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/source/LakeSoulLookupTableSource.java
index 1c98e50bf..b4298b14c 100644
--- a/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/source/LakeSoulLookupTableSource.java
+++ b/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/source/LakeSoulLookupTableSource.java
@@ -235,7 +235,7 @@ public DynamicTableSource copy() {
         lsts.projectedFields = this.projectedFields;
         lsts.remainingPartitions = this.remainingPartitions;
         lsts.filter = this.filter;
-        lsts.filterPlan = this.filterPlan;
+        lsts.filterStr = this.filterStr;
         return lsts;
     }
 
diff --git a/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/source/LakeSoulOneSplitRecordsReader.java b/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/source/LakeSoulOneSplitRecordsReader.java
index 40be9cb00..2e5b86704 100644
--- a/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/source/LakeSoulOneSplitRecordsReader.java
+++ b/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/source/LakeSoulOneSplitRecordsReader.java
@@ -74,9 +74,8 @@ public class LakeSoulOneSplitRecordsReader implements RecordsWithSplitIds<RowData>
     // arrow batch -> row, with requested schema
     private ArrowReader curArrowReaderRequestedSchema;
 
-    private final FilterPredicate filter;
-
-    private final Plan filterPlan;
+    private final FilterPredicate filterPredicate;
+    private final Plan filter;
 
     public LakeSoulOneSplitRecordsReader(Configuration conf,
                                          LakeSoulSplit split,
@@ -85,8 +84,8 @@ public LakeSoulOneSplitRecordsReader(Configuration conf,
                                          List<String> pkColumns,
                                          boolean isStreaming,
                                          String cdcColumn,
-                                         FilterPredicate filter,
-                                         Plan filterPlan)
+                                         FilterPredicate filterPredicate,
+                                         Plan filter)
             throws Exception {
         this.split = split;
         this.skipRecords = split.getSkipRecord();
@@ -98,8 +97,8 @@ public LakeSoulOneSplitRecordsReader(Configuration conf,
         this.isStreaming = isStreaming;
         this.cdcColumn = cdcColumn;
         this.finishedSplit = Collections.singleton(splitId);
+        this.filterPredicate = filterPredicate;
         this.filter = filter;
-        this.filterPlan = filterPlan;
         initializeReader();
         recoverFromSkipRecord();
     }
@@ -134,10 +133,11 @@ private void initializeReader() throws IOException {
         }
 
         if (filter != null) {
-            reader.addFilter(filter.toString());
+            reader.addFilterProto(this.filter);
         }
-        if (filterPlan != null) {
-            reader.addFilterProto(this.filterPlan);
+
+        if (filterPredicate != null) {
+            reader.addFilter(filterPredicate.toString());
         }
 
         LOG.info("Initializing reader for split {}, pk={}, partitions={}," +
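Note on the reader hunks above: after the swap, `filter` consistently names the Substrait plan and `filterPredicate` the legacy Parquet predicate, and the two pushdown paths reach the native reader independently. A minimal sketch of the resulting handoff, assuming the surrounding reader setup is done elsewhere (NativeIOReader, addFilterProto, and addFilter are the calls already present in this patch):

    // sketch only: both filter representations reach the native reader
    NativeIOReader reader = new NativeIOReader();
    if (substraitPlan != null) {          // io.substrait.proto.Plan
        reader.addFilterProto(substraitPlan);
    }
    if (parquetPredicate != null) {       // org.apache.parquet FilterPredicate, pushed as a string
        reader.addFilter(parquetPredicate.toString());
    }

The local names substraitPlan and parquetPredicate are hypothetical stand-ins for the fields renamed in this patch.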
diff --git a/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/source/LakeSoulSource.java b/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/source/LakeSoulSource.java
index e590d1031..eb0f2a733 100644
--- a/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/source/LakeSoulSource.java
+++ b/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/source/LakeSoulSource.java
@@ -43,12 +43,10 @@ public class LakeSoulSource implements Source<RowData, LakeSoulSplit, LakeSoulPendingSplits> {
     @Nullable
     List<Map<String, String>> remainingPartitions;
 
-    // TODO remove this
     @Nullable
-    FilterPredicate filter;
-
+    FilterPredicate filterStr;
     @Nullable
-    Plan filterPlan;
+    Plan filter;
 
     public LakeSoulSource(TableId tableId,
                           RowType rowType,
@@ -57,8 +55,8 @@ public LakeSoulSource(TableId tableId,
                           List<String> pkColumns,
                           Map<String, String> optionParams,
                           @Nullable List<Map<String, String>> remainingPartitions,
-                          @Nullable FilterPredicate filter,
-                          @Nullable Plan filterPlan) {
+                          @Nullable FilterPredicate filterStr,
+                          @Nullable Plan filter) {
         this.tableId = tableId;
         this.rowType = rowType;
         this.rowTypeWithPk = rowTypeWithPk;
@@ -66,8 +64,8 @@ public LakeSoulSource(TableId tableId,
         this.pkColumns = pkColumns;
         this.optionParams = optionParams;
         this.remainingPartitions = remainingPartitions;
+        this.filterStr = filterStr;
         this.filter = filter;
-        this.filterPlan = filterPlan;
     }
 
     @Override
@@ -90,8 +88,8 @@ public SourceReader<RowData, LakeSoulSplit> createReader(SourceReaderContext readerContext) throws Exception {
                         this.pkColumns,
                         this.isStreaming,
                         this.optionParams.getOrDefault(LakeSoulSinkOptions.CDC_CHANGE_COLUMN, ""),
-                        this.filter,
-                        this.filterPlan),
+                        this.filterStr,
+                        this.filter),
                 new LakeSoulRecordEmitter(),
                 readerContext.getConfiguration(),
                 readerContext);
 
diff --git a/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/source/LakeSoulSplitReader.java b/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/source/LakeSoulSplitReader.java
index 49365b511..57fd486fa 100644
--- a/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/source/LakeSoulSplitReader.java
+++ b/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/source/LakeSoulSplitReader.java
@@ -40,10 +40,8 @@ public class LakeSoulSplitReader implements SplitReader<RowData, LakeSoulSplit> {
 
     String cdcColumn;
 
-    // TODO remove this
-    FilterPredicate filter;
-
-    Plan filterPlan;
+    FilterPredicate filterStr;
+    Plan filter;
 
     private LakeSoulOneSplitRecordsReader lastSplitReader;
 
@@ -53,8 +51,8 @@ public LakeSoulSplitReader(Configuration conf,
                                List<String> pkColumns,
                                boolean isStreaming,
                                String cdcColumn,
-                               FilterPredicate filter,
-                               Plan filterPlan) {
+                               FilterPredicate filterStr,
+                               Plan filter) {
         this.conf = conf;
         this.splits = new ArrayDeque<>();
         this.rowType = rowType;
@@ -62,8 +60,8 @@ public LakeSoulSplitReader(Configuration conf,
         this.pkColumns = pkColumns;
         this.isStreaming = isStreaming;
         this.cdcColumn = cdcColumn;
+        this.filterStr = filterStr;
         this.filter = filter;
-        this.filterPlan = filterPlan;
     }
 
     @Override
@@ -78,8 +76,8 @@ public RecordsWithSplitIds<RowData> fetch() throws IOException {
                     this.pkColumns,
                     this.isStreaming,
                     this.cdcColumn,
-                    this.filter,
-                    this.filterPlan
+                    this.filterStr,
+                    this.filter
             );
             return lastSplitReader;
         } catch (Exception e) {
 
diff --git a/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/substrait/SubstraitUtil.java b/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/substrait/SubstraitUtil.java
index d0e7a1347..089df133d 100644
--- a/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/substrait/SubstraitUtil.java
+++ b/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/substrait/SubstraitUtil.java
@@ -55,9 +55,10 @@ public static Tuple2<SupportsFilterPushDown.Result, io.substrait.proto.Plan> toPlan(
             String tableSchema
     ) {
         List<ResolvedExpression> accepted = new ArrayList<>();
+        Schema arrowSchema = toArrowSchema(tableSchema);
         Expression last = null;
         for (ResolvedExpression expr : exprs) {
-            Expression e = doTransform(expr);
+            Expression e = doTransform(expr, arrowSchema);
             if (e == null) {
                 remaining.add(expr);
             } else {
@@ -70,44 +71,50 @@ public static Tuple2<SupportsFilterPushDown.Result, io.substrait.proto.Plan> toPlan(
                 }
             }
         }
-        Plan filter = toFilter(last, tableName, tableSchema);
+        Plan filter = toFilter(last, tableName, arrowSchema);
         return Tuple2.of(SupportsFilterPushDown.Result.of(accepted, remaining), planToProto(filter));
     }
 
-    static Plan toFilter(Expression e, String tableName, String tableSchema) {
+    static Schema toArrowSchema(String tableSchema) {
         try {
-            Schema arrow_schema = Schema.fromJSON(tableSchema);
-            List<String> tableNames = Stream.of(tableName).collect(Collectors.toList());
-            List<String> columnNames = new ArrayList<>();
-            List<Field> fields = arrow_schema.getFields();
-            List<Type> columnTypes = new ArrayList<>();
-            for (Field field : fields) {
-                Type type = fromArrowType(field.getType(), field.isNullable());
-                if (type == null) {
-                    return null;
-                }
-                columnTypes.add(type);
-                String name = field.getName();
-                columnNames.add(name);
+            return Schema.fromJSON(tableSchema);
+        } catch (IOException e) {
+            // FIXME handle this more gracefully
+            throw new RuntimeException(e);
+        }
+    }
+
+    static Plan toFilter(Expression e, String tableName, Schema arrowSchema) {
+        if (e == null) {
+            return null;
+        }
+        List<String> tableNames = Stream.of(tableName).collect(Collectors.toList());
+        List<String> columnNames = new ArrayList<>();
+        List<Field> fields = arrowSchema.getFields();
+        List<Type> columnTypes = new ArrayList<>();
+        for (Field field : fields) {
+            Type type = fromArrowType(field.getType(), field.isNullable());
+            if (type == null) {
+                return null;
             }
-            NamedScan namedScan = Builder.namedScan(tableNames, columnNames, columnTypes);
-            namedScan =
-                    NamedScan.builder()
-                            .from(namedScan)
-                            .filter(e)
-                            .build();
+            columnTypes.add(type);
+            String name = field.getName();
+            columnNames.add(name);
+        }
+        NamedScan namedScan = Builder.namedScan(tableNames, columnNames, columnTypes);
+        namedScan =
+                NamedScan.builder()
+                        .from(namedScan)
+                        .filter(e)
+                        .build();
 
-            Plan.Root root = Builder.root(namedScan);
-            return Builder.plan(root);
-        } catch (IOException ex) {
-            // FIXME fix this elegantly
-            throw new RuntimeException(ex);
-        }
+        Plan.Root root = Builder.root(namedScan);
+        return Builder.plan(root);
     }
 
-    public static Expression doTransform(ResolvedExpression flinkExpression) {
-        SubstraitVisitor substraitVisitor = new SubstraitVisitor();
+    public static Expression doTransform(ResolvedExpression flinkExpression, Schema arrowSchema) {
+        SubstraitVisitor substraitVisitor = new SubstraitVisitor(arrowSchema);
         return flinkExpression.accept(substraitVisitor);
     }
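The SubstraitUtil changes above thread the table's Arrow schema through the whole translation: schema parsing moves into toArrowSchema, and doTransform now requires the schema. A usage sketch of the new entry points (the schema JSON string, the table name, and the order_id column are invented for illustration; the expression construction follows the pattern of the test file removed at the end of this patch):

    // sketch: Flink predicate "order_id = 3" -> Substrait filter plan
    Schema schema = SubstraitUtil.toArrowSchema(tableSchemaJson); // tableSchemaJson: hypothetical JSON string
    FieldReferenceExpression orderId = new FieldReferenceExpression("order_id",
            new AtomicDataType(new IntType()), 0, 0);
    ValueLiteralExpression three = new ValueLiteralExpression(3, new AtomicDataType(new IntType(false)));
    CallExpression pred = CallExpression.permanent(BuiltInFunctionDefinitions.EQUALS,
            Arrays.asList(orderId, three), new AtomicDataType(new BooleanType()));
    Expression e = SubstraitUtil.doTransform(pred, schema);   // null if the expression is unsupported
    Plan plan = SubstraitUtil.toFilter(e, "a_table", schema); // null if any column type is unsupported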
diff --git a/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/substrait/SubstraitVisitor.java b/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/substrait/SubstraitVisitor.java
index 87ab6d75e..dce0744cc 100644
--- a/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/substrait/SubstraitVisitor.java
+++ b/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/substrait/SubstraitVisitor.java
@@ -1,11 +1,13 @@
 package org.apache.flink.lakesoul.substrait;
 
 import com.google.common.collect.ImmutableMap;
 import io.substrait.expression.*;
 import io.substrait.expression.Expression;
 import io.substrait.extension.SimpleExtension;
 import io.substrait.type.Type;
 import io.substrait.type.TypeCreator;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.Schema;
 import org.apache.flink.table.expressions.*;
 import org.apache.flink.table.expressions.ExpressionVisitor;
 import org.apache.flink.table.functions.BuiltInFunctionDefinitions;
@@ -21,18 +23,24 @@
 import java.time.Instant;
 import java.time.LocalDate;
 import java.util.*;
-import java.util.function.Function;
+import java.util.function.BiFunction;
 
 /**
  * return null means cannot convert
 */
 public class SubstraitVisitor implements ExpressionVisitor<Expression> {
 
+    public SubstraitVisitor(Schema arrowSchema) {
+        this.arrowSchema = arrowSchema;
+    }
+
     private static final Logger LOG = LoggerFactory.getLogger(SubstraitVisitor.class);
 
+    private Schema arrowSchema;
+
     @Override
     public Expression visit(CallExpression call) {
-        CallExprVisitor callVisitor = new CallExprVisitor();
+        CallExprVisitor callVisitor = new CallExprVisitor(this.arrowSchema);
         return callVisitor.visit(call);
     }
@@ -43,7 +53,7 @@ public Expression visit(ValueLiteralExpression valueLiteral) {
 
     @Override
     public Expression visit(FieldReferenceExpression fieldReference) {
-        return new FieldRefVisitor().visit(fieldReference);
+        return new FieldRefVisitor(this.arrowSchema).visit(fieldReference);
     }
 
     @Override
@@ -91,7 +101,7 @@ public Expression.Literal visit(ValueLiteralExpression valueLiteral) {
                 if (value != null) {
                     s = (String) value;
                 }
-                return ExpressionCreator.varChar(nullable, s, s.length());
+                return ExpressionCreator.string(nullable, s);
             }
             case BOOLEAN: {
                 boolean b = false;
@@ -181,6 +191,11 @@ protected Expression.Literal defaultMethod(org.apache.flink.table.expressions.Expression expression) {
 }
 
 class FieldRefVisitor extends ExpressionDefaultVisitor<FieldReference> {
+    public FieldRefVisitor(Schema arrowSchema) {
+        this.arrowSchema = arrowSchema;
+    }
+
+    private Schema arrowSchema;
 
     private static final Logger LOG = LoggerFactory.getLogger(FieldRefVisitor.class);
@@ -192,11 +207,31 @@ public FieldReference visit(FieldReferenceExpression fieldReference) {
         }
         LogicalType logicalType = fieldReference.getOutputDataType().getLogicalType();
         LogicalTypeRoot typeRoot = logicalType.getTypeRoot();
+        Type type = mapType(typeRoot, logicalType.isNullable());
+        if (type == null) {
+            // not supported
+            return null;
+        }
+        String name = fieldReference.getName();
+        int idx = -1;
+        List<Field> fields = arrowSchema.getFields();
+        // resolve the field index by name against the arrow schema
+        for (int i = 0; i < fields.size(); i++) {
+            if (fields.get(i).getName().equals(name)) {
+                idx = i;
+                break;
+            }
+        }
+        if (idx < 0) {
+            // field absent from the arrow schema; cannot build a valid reference
+            return null;
+        }
+
         return FieldReference.builder()
-                .type(Objects.requireNonNull(mapType(typeRoot, logicalType.isNullable())))
+                .type(Objects.requireNonNull(type))
                 .addSegments(
                         ImmutableStructField.builder()
-                                .offset(fieldReference.getFieldIndex())
+                                .offset(idx)
                                 .build()
                 )
                 .build();
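The offset change above stops trusting fieldReference.getFieldIndex() and resolves the column's position in the full Arrow schema instead; presumably the Flink index refers to the (possibly projected) row type, while the Substrait NamedScan is always built over the complete table schema. A toy illustration of the discrepancy (the two-column schema is invented):

    // sketch: index in the projected row vs. offset in the table schema
    Schema tableSchema = new Schema(Arrays.asList(
            Field.nullable("id", new ArrowType.Int(32, true)),
            Field.nullable("price", new ArrowType.Int(64, true))));
    // after projecting only "price", Flink may report getFieldIndex() == 0,
    // but a scan built from tableSchema expects offset 1:
    int idx = -1;
    List<Field> fields = tableSchema.getFields();
    for (int i = 0; i < fields.size(); i++) {
        if (fields.get(i).getName().equals("price")) { idx = i; break; }
    }
    // idx == 1 here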
@@ -208,39 +243,41 @@ protected FieldReference defaultMethod(org.apache.flink.table.expressions.Expression expression) {
     }
 
     private Type mapType(LogicalTypeRoot typeRoot, Boolean nullable) {
+        TypeCreator R = TypeCreator.of(nullable);
         switch (typeRoot) {
             case CHAR:
             case VARCHAR: {
-                return Type.VarChar.builder().nullable(nullable).build();
+                // datafusion only supports STRING
+                return R.STRING;
             }
             case BOOLEAN: {
-                return Type.Bool.builder().nullable(nullable).build();
+                return R.BOOLEAN;
             }
             case BINARY:
             case VARBINARY: {
-                return Type.Binary.builder().nullable(nullable).build();
+                return R.BINARY;
             }
             case TINYINT:
             case SMALLINT:
             case INTEGER: {
-                return Type.I32.builder().nullable(nullable).build();
+                return R.I32;
            }
             case BIGINT: {
-                return Type.I64.builder().nullable(nullable).build();
+                return R.I64;
             }
             case FLOAT: {
-                return Type.FP32.builder().nullable(nullable).build();
+                return R.FP32;
             }
             case DOUBLE: {
-                return Type.FP64.builder().nullable(nullable).build();
+                return R.FP64;
             }
             case DATE: {
-                return Type.Date.builder().nullable(nullable).build();
+                return R.DATE;
             }
             case TIMESTAMP_WITHOUT_TIME_ZONE:
             case TIMESTAMP_WITH_TIME_ZONE:
             case TIMESTAMP_WITH_LOCAL_TIME_ZONE: {
-                return Type.Timestamp.builder().nullable(nullable).build();
+                return R.TIMESTAMP;
             }
             default:
                 LOG.info("unsupported type");
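The mapType rewrite above is a mechanical simplification: TypeCreator.of(nullable) hands back types equal to what the per-type builders constructed, except for the CHAR/VARCHAR branch, which now maps to STRING rather than VarChar (per the in-line comment, datafusion only supports STRING on the native side). The two styles side by side, for one representative case:

    // equivalent constructions of a nullable 32-bit integer type
    Type viaBuilder = Type.I32.builder().nullable(true).build(); // old style, removed above
    Type viaCreator = TypeCreator.of(true).I32;                  // new style, added above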
@@ -252,22 +289,28 @@ private Type mapType(LogicalTypeRoot typeRoot, Boolean nullable) {
 }
 
 class CallExprVisitor extends ExpressionDefaultVisitor<Expression> {
+
+    public CallExprVisitor(Schema arrowSchema) {
+        this.arrowSchema = arrowSchema;
+    }
+
+    private Schema arrowSchema;
     private static final Logger LOG = LoggerFactory.getLogger(CallExprVisitor.class);
-    private static final ImmutableMap<FunctionDefinition, Function<CallExpression, Expression>>
-            FILTERS = new ImmutableMap.Builder<FunctionDefinition, Function<CallExpression, Expression>>()
-            .put(BuiltInFunctionDefinitions.IS_NULL, call -> makeUnaryFunction(call, "is_null:any", SubstraitUtil.CompNamespace))
-            .put(BuiltInFunctionDefinitions.IS_NOT_NULL, call -> makeUnaryFunction(call, "is_not_null:any", SubstraitUtil.CompNamespace))
-            .put(BuiltInFunctionDefinitions.NOT, call -> makeUnaryFunction(call, "not:bool", SubstraitUtil.BooleanNamespace))
-            .put(BuiltInFunctionDefinitions.OR, call -> makeBinaryFunction(call, "or:bool", SubstraitUtil.BooleanNamespace))
-            .put(BuiltInFunctionDefinitions.AND, call -> makeBinaryFunction(call, "and:bool", SubstraitUtil.BooleanNamespace))
-            .put(BuiltInFunctionDefinitions.EQUALS, call -> makeBinaryFunction(call, "equal:any_any", SubstraitUtil.CompNamespace))
-            .put(BuiltInFunctionDefinitions.NOT_EQUALS, call -> makeBinaryFunction(call, "not_equal:any_any", SubstraitUtil.CompNamespace))
-            .put(BuiltInFunctionDefinitions.GREATER_THAN, call -> makeBinaryFunction(call, "gt:any_any", SubstraitUtil.CompNamespace))
-            .put(BuiltInFunctionDefinitions.GREATER_THAN_OR_EQUAL, call -> makeBinaryFunction(call, "gte:any_any", SubstraitUtil.CompNamespace))
-            .put(BuiltInFunctionDefinitions.LESS_THAN, call -> makeBinaryFunction(call, "lt:any_any", SubstraitUtil.CompNamespace))
-            .put(BuiltInFunctionDefinitions.LESS_THAN_OR_EQUAL, call -> makeBinaryFunction(call, "lte:any_any", SubstraitUtil.CompNamespace))
+    private static final ImmutableMap<FunctionDefinition, BiFunction<CallExpression, Schema, Expression>>
+            FILTERS = new ImmutableMap.Builder<FunctionDefinition, BiFunction<CallExpression, Schema, Expression>>()
+            .put(BuiltInFunctionDefinitions.IS_NULL, (call, schema) -> makeUnaryFunction(call, schema, "is_null:any", SubstraitUtil.CompNamespace))
+            .put(BuiltInFunctionDefinitions.IS_NOT_NULL, (call, schema) -> makeUnaryFunction(call, schema, "is_not_null:any", SubstraitUtil.CompNamespace))
+            .put(BuiltInFunctionDefinitions.NOT, (call, schema) -> makeUnaryFunction(call, schema, "not:bool", SubstraitUtil.BooleanNamespace))
+            .put(BuiltInFunctionDefinitions.OR, (call, schema) -> makeBinaryFunction(call, schema, "or:bool", SubstraitUtil.BooleanNamespace))
+            .put(BuiltInFunctionDefinitions.AND, (call, schema) -> makeBinaryFunction(call, schema, "and:bool", SubstraitUtil.BooleanNamespace))
+            .put(BuiltInFunctionDefinitions.EQUALS, (call, schema) -> makeBinaryFunction(call, schema, "equal:any_any", SubstraitUtil.CompNamespace))
+            .put(BuiltInFunctionDefinitions.NOT_EQUALS, (call, schema) -> makeBinaryFunction(call, schema, "not_equal:any_any", SubstraitUtil.CompNamespace))
+            .put(BuiltInFunctionDefinitions.GREATER_THAN, (call, schema) -> makeBinaryFunction(call, schema, "gt:any_any", SubstraitUtil.CompNamespace))
+            .put(BuiltInFunctionDefinitions.GREATER_THAN_OR_EQUAL, (call, schema) -> makeBinaryFunction(call, schema, "gte:any_any", SubstraitUtil.CompNamespace))
+            .put(BuiltInFunctionDefinitions.LESS_THAN, (call, schema) -> makeBinaryFunction(call, schema, "lt:any_any", SubstraitUtil.CompNamespace))
+            .put(BuiltInFunctionDefinitions.LESS_THAN_OR_EQUAL, (call, schema) -> makeBinaryFunction(call, schema, "lte:any_any", SubstraitUtil.CompNamespace))
             .build();
 
     @Override
@@ -279,13 +322,13 @@ public Expression visit(CallExpression call) {
                     call);
             return null;
         }
-        return FILTERS.get(call.getFunctionDefinition()).apply(call);
+        return FILTERS.get(call.getFunctionDefinition()).apply(call, this.arrowSchema);
     }
 
-    static Expression makeBinaryFunction(CallExpression call, String funcKey, String namespace) {
+    static Expression makeBinaryFunction(CallExpression call, Schema arrowSchema, String funcKey, String namespace) {
         List<org.apache.flink.table.expressions.Expression> children = call.getChildren();
         assert children.size() == 2;
-        SubstraitVisitor visitor = new SubstraitVisitor();
+        SubstraitVisitor visitor = new SubstraitVisitor(arrowSchema);
         Expression left = children.get(0).accept(visitor);
         Expression right = children.get(1).accept(visitor);
         if (left == null || right == null) {
@@ -298,10 +341,10 @@ static Expression makeBinaryFunction(CallExpression call, Schema arrowSchema, String funcKey, String namespace) {
         return ExpressionCreator.scalarFunction(func, TypeCreator.NULLABLE.BOOLEAN, args);
     }
 
-    static Expression makeUnaryFunction(CallExpression call, String funcKey, String namespace) {
+    static Expression makeUnaryFunction(CallExpression call, Schema arrowSchema, String funcKey, String namespace) {
         List<org.apache.flink.table.expressions.Expression> children = call.getChildren();
         assert children.size() == 1;
-        SubstraitVisitor visitor = new SubstraitVisitor();
+        SubstraitVisitor visitor = new SubstraitVisitor(arrowSchema);
         Expression child = children.get(0).accept(visitor);
         if (child == null) {
             return null;
 
diff --git a/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/table/LakeSoulTableSource.java b/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/table/LakeSoulTableSource.java
index fc800c6a5..c03c903e9 100644
--- a/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/table/LakeSoulTableSource.java
+++ b/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/table/LakeSoulTableSource.java
@@ -63,9 +63,9 @@ public class LakeSoulTableSource
     protected List<Map<String, String>> remainingPartitions;
 
     // TODO remove this
-    protected FilterPredicate filter;
+    protected FilterPredicate filterStr;
     // TODO merge
-    protected io.substrait.proto.Plan filterPlan;
+    protected io.substrait.proto.Plan filter;
 
     public LakeSoulTableSource(TableId tableId,
                                RowType rowType,
@@ -89,7 +89,7 @@ public DynamicTableSource copy() {
         lsts.projectedFields = this.projectedFields;
         lsts.remainingPartitions = this.remainingPartitions;
         lsts.filter = this.filter;
-        lsts.filterPlan = this.filterPlan;
+        lsts.filterStr = this.filterStr;
         return lsts;
     }
 
@@ -110,23 +110,22 @@ public Result applyFilters(List<ResolvedExpression> filters) {
         DBUtil.TablePartitionKeys partitionKeys = DBUtil.parseTableInfoPartitions(tableInfo.getPartitions());
         Set<String> partitionCols = new HashSet<>(partitionKeys.rangeKeys);
         for (ResolvedExpression filter : filters) {
-            if (ParquetFilters.filterContainsPartitionColumn(filter, partitionCols)) {
+            if (SubstraitUtil.filterContainsPartitionColumn(filter, partitionCols)) {
                 remainingFilters.add(filter);
             } else {
                 nonPartitionFilters.add(filter);
             }
         }
         // find acceptable non partition filters
-        Tuple2<Result, FilterPredicate> filterPushDownResult = ParquetFilters.toParquetFilter(nonPartitionFilters,
+        Tuple2<Result, FilterPredicate> filterPushDownRes = ParquetFilters.toParquetFilter(nonPartitionFilters,
                 remainingFilters);
-        Tuple2<Result, io.substrait.proto.Plan> filterPlanRes = SubstraitUtil.toPlan(nonPartitionFilters,
+        Tuple2<Result, io.substrait.proto.Plan> filterPushDownResult = SubstraitUtil.toPlan(nonPartitionFilters,
                 remainingFilters, tableInfo.getTableName(), tableInfo.getTableSchema());
         this.filter = filterPushDownResult.f1;
-        this.filterPlan = filterPlanRes.f1;
+        this.filterStr = filterPushDownRes.f1;
         LOG.info("Applied filters to native io: {}, accepted {}, remaining {}", this.filter,
                 filterPushDownResult.f0.getAcceptedFilters(), filterPushDownResult.f0.getRemainingFilters());
-        LOG.info("FilterPlan: {}", this.filterPlan);
         return filterPushDownResult.f0;
     }
 
@@ -225,8 +224,8 @@ public ScanRuntimeProvider getScanRuntimeProvider(ScanContext runtimeProviderContext) {
                 this.pkColumns,
                 this.optionParams,
                 this.remainingPartitions,
-                this.filter,
-                this.filterPlan));
+                this.filterStr,
+                this.filter));
     }
 
     @Override
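In the applyFilters hunk above, partition-column predicates are withheld from pushdown (Flink keeps evaluating them), the rest are offered to both translators, and only the Substrait Result is reported back to the planner. Condensed control flow, using the types from the hunk (illustrative only):

    // sketch: the post-patch pushdown decision in applyFilters
    Tuple2<Result, FilterPredicate> parquetRes =
            ParquetFilters.toParquetFilter(nonPartitionFilters, remainingFilters);
    Tuple2<Result, io.substrait.proto.Plan> substraitRes =
            SubstraitUtil.toPlan(nonPartitionFilters, remainingFilters,
                    tableInfo.getTableName(), tableInfo.getTableSchema());
    this.filterStr = parquetRes.f1;   // legacy Parquet predicate
    this.filter = substraitRes.f1;    // Substrait plan
    return substraitRes.f0;           // accepted/remaining as decided by the Substrait path

One hedged observation: both translator calls receive the same mutable remainingFilters list, so a predicate rejected by both paths could end up in the list twice; this may be worth checking.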
diff --git a/lakesoul-flink/src/test/java/org/apache/flink/lakesoul/test/substrait/SubstraitTest.java b/lakesoul-flink/src/test/java/org/apache/flink/lakesoul/test/substrait/SubstraitTest.java
deleted file mode 100644
index af473d892..000000000
--- a/lakesoul-flink/src/test/java/org/apache/flink/lakesoul/test/substrait/SubstraitTest.java
+++ /dev/null
@@ -1,136 +0,0 @@
-package org.apache.flink.lakesoul.test.substrait;
-
-import com.dmetasoul.lakesoul.lakesoul.io.NativeIOReader;
-import io.substrait.dsl.SubstraitBuilder;
-import io.substrait.expression.*;
-import io.substrait.expression.Expression;
-import io.substrait.expression.proto.ExpressionProtoConverter;
-import io.substrait.extension.ExtensionCollector;
-import io.substrait.extension.SimpleExtension;
-import io.substrait.plan.Plan;
-import io.substrait.plan.PlanProtoConverter;
-import io.substrait.relation.*;
-import io.substrait.type.Type;
-import io.substrait.type.TypeCreator;
-import org.apache.flink.lakesoul.substrait.SubstraitUtil;
-import org.apache.flink.table.expressions.*;
-import org.apache.flink.table.functions.BuiltInFunctionDefinition;
-import org.apache.flink.table.functions.BuiltInFunctionDefinitions;
-import org.apache.flink.table.types.AtomicDataType;
-import org.apache.flink.table.types.logical.BooleanType;
-import org.apache.flink.table.types.logical.IntType;
-import org.junit.Test;
-
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.stream.Collectors;
-import java.util.stream.Stream;
-
-public class SubstraitTest {
-    @Test
-    public void generalExprTest() throws IOException {
-        ValueLiteralExpression valExpr = new ValueLiteralExpression(3, new AtomicDataType(new IntType(false)));
-        FieldReferenceExpression orderId = new FieldReferenceExpression("order_id",
-                new AtomicDataType(new IntType())
-                , 0, 0);
-        List<ResolvedExpression> args = new ArrayList<>();
-        args.add(orderId);
-        args.add(valExpr);
-        // naive binary func : gt gte lt lte
-        CallExpression eq = funcInvoke(BuiltInFunctionDefinitions.EQUALS, args);
-        CallExpression lt = funcInvoke(BuiltInFunctionDefinitions.LESS_THAN, args);
-        CallExpression lte = funcInvoke(BuiltInFunctionDefinitions.LESS_THAN_OR_EQUAL, args);
-        CallExpression gt = funcInvoke(BuiltInFunctionDefinitions.GREATER_THAN, args);
-        CallExpression gte = funcInvoke(BuiltInFunctionDefinitions.GREATER_THAN_OR_EQUAL, args);
-        CallExpression and = funcInvoke(BuiltInFunctionDefinitions.AND, args);
-        CallExpression or = funcInvoke(BuiltInFunctionDefinitions.OR, args);
-        // naive unary func : is null , is not null
-        args.clear();
-        args.add(orderId);
-        CallExpression nonNull = funcInvoke(BuiltInFunctionDefinitions.IS_NOT_NULL, args);
-        CallExpression isNull = funcInvoke(BuiltInFunctionDefinitions.IS_NULL, args);
-        // compound expr
-        args.clear();
-        args.add(lt);
-        args.add(lte);
-        funcInvoke(BuiltInFunctionDefinitions.EQUALS, args);
-        args.clear();
-        args.add(isNull);
-        funcInvoke(BuiltInFunctionDefinitions.NOT, args);
-    }
-
-    CallExpression funcInvoke(BuiltInFunctionDefinition func, List<ResolvedExpression> args) {
-        CallExpression expr = CallExpression.permanent(func, args, new AtomicDataType(new BooleanType()));
-        Expression e = SubstraitUtil.doTransform(expr);
-        System.out.println(e);
-        return expr;
-    }
-
-    @Test
-    public void literalExprTest() {
-        ValueLiteralExpression valExpr = new ValueLiteralExpression(3, new AtomicDataType(new IntType(false)));
-        Expression substraitExpr = SubstraitUtil.doTransform(valExpr);
-        System.out.println(substraitExpr);
-    }
-
-    @Test
-    public void FieldRefTest() {
-        FieldReferenceExpression orderId = new FieldReferenceExpression("order_id",
-                new AtomicDataType(new IntType())
-                , 0, 0);
-        Expression expr = SubstraitUtil.doTransform(orderId);
-        System.out.println(expr);
-        System.out.println(toProto(null, expr));
-    }
-
-    private io.substrait.proto.Expression toProto(ExtensionCollector collector, Expression expr) {
-        return expr.accept(new ExpressionProtoConverter(collector, null));
-    }
-
-    @Test
-    public void callExprTest() {
-        try {
-            SimpleExtension.ExtensionCollection extensionCollection = SimpleExtension.loadDefaults();
-            System.out.println(extensionCollection.scalarFunctions());
-            SimpleExtension.ScalarFunctionVariant desc = extensionCollection.getScalarFunction(SimpleExtension.FunctionAnchor.of("/functions_comparison.yaml", "equal:any_any"));
-            Expression.ScalarFunctionInvocation si = ExpressionCreator.scalarFunction(desc, TypeCreator.NULLABLE.I32);
-            io.substrait.proto.Expression p = toProto(new ExtensionCollector(), si);
-            System.out.println(p);
-        } catch (IOException e) {
-            throw new RuntimeException(e);
-        }
-    }
-
-    @Test
-    public void endToEndTest() {
-        try {
-            SimpleExtension.ExtensionCollection se = SimpleExtension.loadDefaults();
-            SubstraitBuilder b = new SubstraitBuilder(se);
-            List<String> tableName = Stream.of("a_table").collect(Collectors.toList());
-            List<String> columnNames = Stream.of("col1", "col2").collect(Collectors.toList());
-            TypeCreator R = TypeCreator.REQUIRED;
-            List<Type> columnTypes = Stream.of(R.I32, R.I32).collect(Collectors.toList());
-            NamedScan namedScan = b.namedScan(tableName, columnNames, columnTypes);
-            namedScan =
-                    NamedScan.builder()
-                            .from(namedScan)
-                            .filter(b.equal(b.fieldReference(namedScan, 1), b.i32(3)))
-                            .build();
-
-            Plan.Root root = b.root(namedScan);
-            Plan plan = b.plan(root);
-            System.out.println(plan);
-            PlanProtoConverter planProtoConverter = new PlanProtoConverter();
-            io.substrait.proto.Plan proto = planProtoConverter.toProto(plan);
-            System.out.println(proto);
-//            byte[] byteArray = proto.toByteArray();
-//            System.out.println(Arrays.toString(byteArray));
-            NativeIOReader reader = new NativeIOReader();
-//            reader.addFilterProto(proto);
-        } catch (IOException e) {
-            throw new RuntimeException(e);
-        }
-    }
-}
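
The deleted SubstraitTest above exercised the old single-argument doTransform and would no longer compile against this patch. A minimal replacement sketch against the new schema-aware API (the class and method names are invented; the one-column schema mirrors the deleted FieldRefTest):

    // sketch: schema-aware variant of the deleted FieldRefTest
    public class SubstraitVisitorSketchTest {
        @Test
        public void fieldRefWithSchema() {
            Schema schema = new Schema(Collections.singletonList(
                    Field.nullable("order_id", new ArrowType.Int(32, true))));
            FieldReferenceExpression orderId = new FieldReferenceExpression("order_id",
                    new AtomicDataType(new IntType()), 0, 0);
            Expression expr = SubstraitUtil.doTransform(orderId, schema);
            System.out.println(expr);
        }
    }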