diff --git a/lakesoul-flink/pom.xml b/lakesoul-flink/pom.xml index ab4a469c5..58b174103 100644 --- a/lakesoul-flink/pom.xml +++ b/lakesoul-flink/pom.xml @@ -443,6 +443,8 @@ SPDX-License-Identifier: Apache-2.0 com.google.code.gson:gson dev.failsafe:failsafe com.google.protobuf:protobuf-java + + io.substrait:core org.apache.logging.log4j:* diff --git a/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/source/LakeSoulLookupTableSource.java b/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/source/LakeSoulLookupTableSource.java index da5816bf5..964dac8ed 100644 --- a/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/source/LakeSoulLookupTableSource.java +++ b/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/source/LakeSoulLookupTableSource.java @@ -234,6 +234,7 @@ public DynamicTableSource copy() { this.optionParams); lsts.projectedFields = this.projectedFields; lsts.remainingPartitions = this.remainingPartitions; + lsts._filterPredicate = this._filterPredicate; lsts.filter = this.filter; return lsts; } diff --git a/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/source/LakeSoulOneSplitRecordsReader.java b/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/source/LakeSoulOneSplitRecordsReader.java index 95997600b..a40adea13 100644 --- a/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/source/LakeSoulOneSplitRecordsReader.java +++ b/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/source/LakeSoulOneSplitRecordsReader.java @@ -6,6 +6,7 @@ import com.dmetasoul.lakesoul.LakeSoulArrowReader; import com.dmetasoul.lakesoul.lakesoul.io.NativeIOReader; +import io.substrait.proto.Plan; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.types.pojo.Schema; @@ -73,7 +74,8 @@ public class LakeSoulOneSplitRecordsReader implements RecordsWithSplitIds row, with requested schema private ArrowReader curArrowReaderRequestedSchema; - private final FilterPredicate filter; + private final FilterPredicate _filterPredicate; + private final Plan filter; public LakeSoulOneSplitRecordsReader(Configuration conf, LakeSoulSplit split, @@ -82,7 +84,8 @@ public LakeSoulOneSplitRecordsReader(Configuration conf, List pkColumns, boolean isStreaming, String cdcColumn, - FilterPredicate filter) + FilterPredicate _filterPredicate, + Plan filter) throws Exception { this.split = split; this.skipRecords = split.getSkipRecord(); @@ -94,6 +97,7 @@ public LakeSoulOneSplitRecordsReader(Configuration conf, this.isStreaming = isStreaming; this.cdcColumn = cdcColumn; this.finishedSplit = Collections.singleton(splitId); + this._filterPredicate = _filterPredicate; this.filter = filter; initializeReader(); recoverFromSkipRecord(); @@ -129,7 +133,11 @@ private void initializeReader() throws IOException { } if (filter != null) { - reader.addFilter(filter.toString()); + reader.addFilterProto(this.filter); + } + + if (_filterPredicate !=null) { + reader.addFilter(_filterPredicate.toString()); } LOG.info("Initializing reader for split {}, pk={}, partitions={}," + diff --git a/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/source/LakeSoulSource.java b/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/source/LakeSoulSource.java index 6f10d9bee..cd4ba7c1f 100644 --- a/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/source/LakeSoulSource.java +++ b/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/source/LakeSoulSource.java @@ -9,6 +9,7 @@ import com.dmetasoul.lakesoul.meta.DataOperation; import 
com.dmetasoul.lakesoul.meta.LakeSoulOptions; import com.dmetasoul.lakesoul.meta.entity.TableInfo; +import io.substrait.proto.Plan; import org.apache.flink.api.connector.source.*; import org.apache.flink.configuration.Configuration; import org.apache.flink.core.fs.Path; @@ -43,7 +44,9 @@ public class LakeSoulSource implements Source> remainingPartitions; @Nullable - FilterPredicate filter; + FilterPredicate _filterPredicate; + @Nullable + Plan filter; public LakeSoulSource(TableId tableId, RowType rowType, @@ -52,7 +55,8 @@ public LakeSoulSource(TableId tableId, List pkColumns, Map optionParams, @Nullable List> remainingPartitions, - @Nullable FilterPredicate filter) { + @Nullable FilterPredicate _filterPredicate, + @Nullable Plan filter) { this.tableId = tableId; this.rowType = rowType; this.rowTypeWithPk = rowTypeWithPk; @@ -60,7 +64,9 @@ public LakeSoulSource(TableId tableId, this.pkColumns = pkColumns; this.optionParams = optionParams; this.remainingPartitions = remainingPartitions; + this._filterPredicate = _filterPredicate; this.filter = filter; + } @Override @@ -83,6 +89,7 @@ public SourceReader createReader(SourceReaderContext rea this.pkColumns, this.isStreaming, this.optionParams.getOrDefault(LakeSoulSinkOptions.CDC_CHANGE_COLUMN, ""), + this._filterPredicate, this.filter), new LakeSoulRecordEmitter(), readerContext.getConfiguration(), diff --git a/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/source/LakeSoulSplitReader.java b/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/source/LakeSoulSplitReader.java index 518fb8ffa..51013e31b 100644 --- a/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/source/LakeSoulSplitReader.java +++ b/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/source/LakeSoulSplitReader.java @@ -4,6 +4,7 @@ package org.apache.flink.lakesoul.source; +import io.substrait.proto.Plan; import org.apache.flink.configuration.Configuration; import org.apache.flink.connector.base.source.reader.RecordsWithSplitIds; import org.apache.flink.connector.base.source.reader.splitreader.SplitReader; @@ -39,7 +40,8 @@ public class LakeSoulSplitReader implements SplitReader String cdcColumn; - FilterPredicate filter; + FilterPredicate _filterPredicate; + Plan filter; private LakeSoulOneSplitRecordsReader lastSplitReader; @@ -49,7 +51,8 @@ public LakeSoulSplitReader(Configuration conf, List pkColumns, boolean isStreaming, String cdcColumn, - FilterPredicate filter) { + FilterPredicate _filterPredicate, + Plan filter) { this.conf = conf; this.splits = new ArrayDeque<>(); this.rowType = rowType; @@ -57,6 +60,7 @@ public LakeSoulSplitReader(Configuration conf, this.pkColumns = pkColumns; this.isStreaming = isStreaming; this.cdcColumn = cdcColumn; + this._filterPredicate = _filterPredicate; this.filter = filter; } @@ -72,7 +76,9 @@ public RecordsWithSplitIds fetch() throws IOException { this.pkColumns, this.isStreaming, this.cdcColumn, - this.filter); + this._filterPredicate, + this.filter + ); return lastSplitReader; } catch (Exception e) { throw new IOException(e); diff --git a/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/substrait/SubstraitFlinkUtil.java b/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/substrait/SubstraitFlinkUtil.java new file mode 100644 index 000000000..bc689e937 --- /dev/null +++ b/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/substrait/SubstraitFlinkUtil.java @@ -0,0 +1,68 @@ +package org.apache.flink.lakesoul.substrait; + +import io.substrait.expression.Expression; +import 
io.substrait.expression.ExpressionCreator; +import io.substrait.extension.SimpleExtension; +import io.substrait.plan.Plan; +import io.substrait.type.TypeCreator; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.table.connector.source.abilities.SupportsFilterPushDown; +import org.apache.flink.table.expressions.CallExpression; +import org.apache.flink.table.expressions.FieldReferenceExpression; +import org.apache.flink.table.expressions.ResolvedExpression; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; + +import static com.dmetasoul.lakesoul.lakesoul.io.substrait.SubstraitUtil.*; + +public class SubstraitFlinkUtil { + + public static Tuple2 flinkExprToSubStraitPlan( + List exprs, + List remaining, + String tableName, + String tableSchema + ) throws IOException { + List accepted = new ArrayList<>(); + Schema arrowSchema = Schema.fromJSON(tableSchema); + Expression last = null; + for (ResolvedExpression expr : exprs) { + Expression e = doTransform(expr,arrowSchema); + if (e == null) { + remaining.add(expr); + } else { + accepted.add(expr); + if (last != null) { + SimpleExtension.FunctionAnchor fa = SimpleExtension.FunctionAnchor.of(BooleanNamespace, "and:bool"); + last = ExpressionCreator.scalarFunction(Se.getScalarFunction(fa), TypeCreator.NULLABLE.BOOLEAN, last, e); + } else { + last = e; + } + } + } + Plan filter = exprToFilter(last, tableName, arrowSchema); + return Tuple2.of(SupportsFilterPushDown.Result.of(accepted, remaining), planToProto(filter)); + } + + public static Expression doTransform(ResolvedExpression flinkExpression, Schema arrow_schema) { + SubstraitVisitor substraitVisitor = new SubstraitVisitor(arrow_schema); + return flinkExpression.accept(substraitVisitor); + } + + public static boolean filterContainsPartitionColumn(ResolvedExpression expression, Set partitionCols) { + if (expression instanceof FieldReferenceExpression) { + return partitionCols.contains(((FieldReferenceExpression) expression).getName()); + } else if (expression instanceof CallExpression) { + for (ResolvedExpression child : expression.getResolvedChildren()) { + if (filterContainsPartitionColumn(child, partitionCols)) { + return true; + } + } + } + return false; + } +} diff --git a/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/substrait/SubstraitVisitor.java b/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/substrait/SubstraitVisitor.java new file mode 100644 index 000000000..bcf8b9d5e --- /dev/null +++ b/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/substrait/SubstraitVisitor.java @@ -0,0 +1,359 @@ +package org.apache.flink.lakesoul.substrait; + +import com.dmetasoul.lakesoul.lakesoul.io.substrait.SubstraitUtil; +import com.google.common.collect.ImmutableMap; +import io.substrait.expression.*; +import io.substrait.expression.Expression; +import io.substrait.extension.SimpleExtension; +import io.substrait.type.Type; +import io.substrait.type.TypeCreator; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.flink.table.expressions.*; +import org.apache.flink.table.expressions.ExpressionVisitor; +import org.apache.flink.table.functions.BuiltInFunctionDefinitions; +import org.apache.flink.table.functions.FunctionDefinition; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.logical.LogicalType; +import 
org.apache.flink.table.types.logical.LogicalTypeRoot; +import org.apache.spark.sql.catalyst.util.DateTimeUtils$; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.sql.Timestamp; +import java.time.Instant; +import java.time.LocalDate; +import java.util.*; +import java.util.function.BiFunction; + + +/** + * Returns null if the expression cannot be converted. + */ +public class SubstraitVisitor implements ExpressionVisitor { + + public SubstraitVisitor(Schema arrow_schema) { + this.arrowSchema = arrow_schema; + } + + + private static final Logger LOG = LoggerFactory.getLogger(SubstraitVisitor.class); + + private Schema arrowSchema; + + @Override + public Expression visit(CallExpression call) { + CallExprVisitor callVisitor = new CallExprVisitor(this.arrowSchema); + return callVisitor.visit(call); + } + + @Override + public Expression visit(ValueLiteralExpression valueLiteral) { + return new LiteralVisitor().visit(valueLiteral); + } + + @Override + public Expression visit(FieldReferenceExpression fieldReference) { + return new FieldRefVisitor(this.arrowSchema).visit(fieldReference); + } + + @Override + public Expression visit(TypeLiteralExpression typeLiteral) { + LOG.error("not supported"); + return null; + } + + @Override + public Expression visit(org.apache.flink.table.expressions.Expression other) { + if (other instanceof CallExpression) { + return this.visit((CallExpression) other); + } else if (other instanceof ValueLiteralExpression) { + return this.visit((ValueLiteralExpression) other); + } else if (other instanceof FieldReferenceExpression) { + return this.visit((FieldReferenceExpression) other); + } else if (other instanceof TypeLiteralExpression) { + return this.visit((TypeLiteralExpression) other); + } else { + LOG.info("not supported"); + return null; + } + } +} + + +class LiteralVisitor extends ExpressionDefaultVisitor { + private static final Logger LOG = LoggerFactory.getLogger(LiteralVisitor.class); + + @Override + public Expression.Literal visit(ValueLiteralExpression valueLiteral) { + DataType dataType = valueLiteral.getOutputDataType(); + LogicalType logicalType = dataType.getLogicalType(); + Optional valueAs = valueLiteral.getValueAs(dataType.getConversionClass()); + Object value = null; + if (valueAs.isPresent()) { + value = valueAs.get(); + } + boolean nullable = logicalType.isNullable(); + LogicalTypeRoot typeRoot = logicalType.getTypeRoot(); + switch (typeRoot) { + case CHAR: + case VARCHAR: { + String s = ""; + if (value != null) { + s = (String) value; + } + return ExpressionCreator.string(nullable, s); + } + case BOOLEAN: { + boolean b = false; + if (value != null) { + b = (Boolean) value; + } + return ExpressionCreator.bool(nullable, b); + } + case BINARY: + case VARBINARY: { + byte[] b = new byte[]{}; + if (value != null) { + b = (byte[]) value; + } + return ExpressionCreator.binary(nullable, b); + } + case TINYINT: + case SMALLINT: + case INTEGER: { + int i = 0; + if (value != null) { + i = (int) value; + } + return ExpressionCreator.i32(nullable, i); + + } + case BIGINT: { + long l = 0; + if (value != null) { + l = (long) value; + } + return ExpressionCreator.i64(nullable, l); + } + case FLOAT: { + float f = 0.0F; + if (value != null) { + f = (float) value; + } + return ExpressionCreator.fp32(nullable, f); + } + case DOUBLE: { + double d = 0.0; + if (value != null) { + d = (double) value; + } + return ExpressionCreator.fp64(nullable, d); + } + case DATE: { + int days = 0; + if (value != null) { + Object o = value; + if (o instanceof Date || o instanceof
LocalDate) { + days = DateTimeUtils$.MODULE$.anyToDays(o); + } else { + LOG.info("Date filter push down not supported"); + return null; + } + } + return ExpressionCreator.date(nullable, days); + } + case TIMESTAMP_WITHOUT_TIME_ZONE: + case TIMESTAMP_WITH_TIME_ZONE: + case TIMESTAMP_WITH_LOCAL_TIME_ZONE: { + long micros = 0; + if (value != null) { + if (value instanceof Timestamp || value instanceof Instant) { + micros = DateTimeUtils$.MODULE$.anyToMicros(value); + } else { + LOG.info("Timestamp filter push down not supported"); + return null; + } + } + return ExpressionCreator.timestamp(nullable, micros); + } + default: + LOG.info("Filter push down not supported"); + break; + } + return null; + } + + @Override + protected Expression.Literal defaultMethod(org.apache.flink.table.expressions.Expression expression) { + return null; + } + +} + +class FieldRefVisitor extends ExpressionDefaultVisitor { + public FieldRefVisitor(Schema arrow_schema) { + this.arrow_schema = arrow_schema; + } + + private Schema arrow_schema; + + private static final Logger LOG = LoggerFactory.getLogger(FieldRefVisitor.class); + + public FieldReference visit(FieldReferenceExpression fieldReference) { + // only care about the last name + // may fail? + while (!fieldReference.getChildren().isEmpty()) { + fieldReference = (FieldReferenceExpression) fieldReference.getChildren().get(0); + } + LogicalType logicalType = fieldReference.getOutputDataType().getLogicalType(); + LogicalTypeRoot typeRoot = logicalType.getTypeRoot(); + Type type = mapType(typeRoot, logicalType.isNullable()); + if (type == null) { + // not supported + return null; + } + String name = fieldReference.getName(); + int idx = 0; + List fields = arrow_schema.getFields(); + // find idx + for (int i = 0; i < fields.size(); i++) { + if (fields.get(i).getName().equals(name)) { + idx = i; + break; + } + } + + return FieldReference.builder() + .type(Objects.requireNonNull(type)) + .addSegments( + ImmutableStructField.builder() + .offset(idx) + .build() + ) + .build(); + } + + @Override + protected FieldReference defaultMethod(org.apache.flink.table.expressions.Expression expression) { + return null; + } + + private Type mapType(LogicalTypeRoot typeRoot, Boolean nullable) { + TypeCreator R = TypeCreator.of(nullable); + switch (typeRoot) { + case CHAR: + case VARCHAR: { + // datafusion only support STRING + return R.STRING; + } + case BOOLEAN: { + return R.BOOLEAN; + } + case BINARY: + case VARBINARY: { + return R.BINARY; + } + case TINYINT: + case SMALLINT: + case INTEGER: { + return R.I32; + } + case BIGINT: { + return R.I64; + } + case FLOAT: { + return R.FP32; + } + case DOUBLE: { + return R.FP64; + } + case DATE: { + return R.DATE; + } + case TIMESTAMP_WITHOUT_TIME_ZONE: + case TIMESTAMP_WITH_TIME_ZONE: + case TIMESTAMP_WITH_LOCAL_TIME_ZONE: { + return R.TIMESTAMP; + } + default: + LOG.info("unsupported type"); + } + return null; + } + + +} + +class CallExprVisitor extends ExpressionDefaultVisitor { + + public CallExprVisitor(Schema arrowSchema) { + this.arrowSchema = arrowSchema; + } + + private Schema arrowSchema; + private static final Logger LOG = LoggerFactory.getLogger(CallExprVisitor.class); + private static final ImmutableMap> + FILTERS = + new ImmutableMap.Builder< + FunctionDefinition, BiFunction>() + .put(BuiltInFunctionDefinitions.IS_NULL, (call, schema) -> makeUnaryFunction(call, schema, "is_null:any", SubstraitUtil.CompNamespace)) + .put(BuiltInFunctionDefinitions.IS_NOT_NULL, (call,schema) -> makeUnaryFunction(call,schema, "is_not_null:any", 
SubstraitUtil.CompNamespace)) + .put(BuiltInFunctionDefinitions.NOT, (call,schema) -> makeUnaryFunction(call,schema, "not:bool", SubstraitUtil.BooleanNamespace)) + .put(BuiltInFunctionDefinitions.OR, (call,schema) -> makeBinaryFunction(call,schema, "or:bool", SubstraitUtil.BooleanNamespace)) + .put(BuiltInFunctionDefinitions.AND, (call,schema) -> makeBinaryFunction(call,schema, "and:bool", SubstraitUtil.BooleanNamespace)) + .put(BuiltInFunctionDefinitions.EQUALS, (call,schema) -> makeBinaryFunction(call,schema, "equal:any_any", SubstraitUtil.CompNamespace)) + .put(BuiltInFunctionDefinitions.NOT_EQUALS, (call,schema) -> makeBinaryFunction(call,schema, "not_equal:any_any", SubstraitUtil.CompNamespace)) + .put(BuiltInFunctionDefinitions.GREATER_THAN, (call,schema) -> makeBinaryFunction(call,schema, "gt:any_any", SubstraitUtil.CompNamespace)) + .put(BuiltInFunctionDefinitions.GREATER_THAN_OR_EQUAL, (call,schema) -> makeBinaryFunction(call,schema, "gte:any_any", SubstraitUtil.CompNamespace)) + .put(BuiltInFunctionDefinitions.LESS_THAN, (call,schema) -> makeBinaryFunction(call,schema, "lt:any_any", SubstraitUtil.CompNamespace)) + .put(BuiltInFunctionDefinitions.LESS_THAN_OR_EQUAL, (call,schema)-> makeBinaryFunction(call,schema, "lte:any_any", SubstraitUtil.CompNamespace)) + .build(); + + @Override + public Expression visit(CallExpression call) { + if (FILTERS.get(call.getFunctionDefinition()) == null) { + // unsupported predicate + LOG.info( + "Unsupported predicate [{}] cannot be pushed into native io.", + call); + return null; + } + return FILTERS.get(call.getFunctionDefinition()).apply(call, this.arrowSchema); + } + + static Expression makeBinaryFunction(CallExpression call, Schema arrow_schema, String funcKey, String namespace) { + List children = call.getChildren(); + assert children.size() == 2; + SubstraitVisitor visitor = new SubstraitVisitor(arrow_schema); + Expression left = children.get(0).accept(visitor); + Expression right = children.get(1).accept(visitor); + if (left == null || right == null) { + return null; + } + SimpleExtension.ScalarFunctionVariant func = SubstraitUtil.Se.getScalarFunction(SimpleExtension.FunctionAnchor.of(namespace, funcKey)); + List args = new ArrayList<>(); + args.add(left); + args.add(right); + return ExpressionCreator.scalarFunction(func, TypeCreator.NULLABLE.BOOLEAN, args); + } + + static Expression makeUnaryFunction(CallExpression call, Schema arrow_schema, String funcKey, String namespace) { + List children = call.getChildren(); + assert children.size() == 1; + SubstraitVisitor visitor = new SubstraitVisitor(arrow_schema); + Expression child = children.get(0).accept(visitor); + if (child == null) { + return null; + } + SimpleExtension.ScalarFunctionVariant func = SubstraitUtil.Se.getScalarFunction(SimpleExtension.FunctionAnchor.of(namespace, funcKey)); + List args = new ArrayList<>(); + args.add(child); + return ExpressionCreator.scalarFunction(func, TypeCreator.NULLABLE.BOOLEAN, args); + } + + @Override + protected Expression defaultMethod(org.apache.flink.table.expressions.Expression expression) { + return null; + } +} diff --git a/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/table/LakeSoulTableSource.java b/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/table/LakeSoulTableSource.java index 4e79afe84..73646e238 100644 --- a/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/table/LakeSoulTableSource.java +++ b/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/table/LakeSoulTableSource.java @@ -9,9 +9,9 @@ import 
com.dmetasoul.lakesoul.meta.DBUtil; import com.dmetasoul.lakesoul.meta.entity.PartitionInfo; import com.dmetasoul.lakesoul.meta.entity.TableInfo; +import io.substrait.proto.Plan; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.lakesoul.source.LakeSoulSource; -import org.apache.flink.lakesoul.source.ParquetFilters; import org.apache.flink.lakesoul.tool.LakeSoulSinkOptions; import org.apache.flink.lakesoul.types.TableId; import org.apache.flink.table.connector.ChangelogMode; @@ -28,11 +28,13 @@ import org.apache.flink.table.types.logical.RowType; import org.apache.flink.table.types.logical.VarCharType; import org.apache.flink.types.RowKind; +import org.apache.flink.lakesoul.substrait.SubstraitFlinkUtil; import org.apache.parquet.filter2.predicate.FilterPredicate; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import javax.annotation.Nullable; +import java.io.IOException; import java.util.*; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -60,7 +62,10 @@ public class LakeSoulTableSource protected List> remainingPartitions; - protected FilterPredicate filter; + // TODO remove this , now used for debug + protected FilterPredicate _filterPredicate; + // TODO merge + protected io.substrait.proto.Plan filter; public LakeSoulTableSource(TableId tableId, RowType rowType, @@ -83,6 +88,7 @@ public DynamicTableSource copy() { this.optionParams); lsts.projectedFields = this.projectedFields; lsts.remainingPartitions = this.remainingPartitions; + lsts._filterPredicate = this._filterPredicate; lsts.filter = this.filter; return lsts; } @@ -104,19 +110,28 @@ public Result applyFilters(List filters) { DBUtil.TablePartitionKeys partitionKeys = DBUtil.parseTableInfoPartitions(tableInfo.getPartitions()); Set partitionCols = new HashSet<>(partitionKeys.rangeKeys); for (ResolvedExpression filter : filters) { - if (ParquetFilters.filterContainsPartitionColumn(filter, partitionCols)) { + if (SubstraitFlinkUtil.filterContainsPartitionColumn(filter, partitionCols)) { remainingFilters.add(filter); } else { nonPartitionFilters.add(filter); } } // find acceptable non partition filters - Tuple2 filterPushDownResult = ParquetFilters.toParquetFilter(nonPartitionFilters, - remainingFilters); +// Tuple2 filterPushDownRes = ParquetFilters.toParquetFilter(nonPartitionFilters, +// remainingFilters); + Tuple2 filterPushDownResult = null; + try { + filterPushDownResult = SubstraitFlinkUtil.flinkExprToSubStraitPlan(nonPartitionFilters, + remainingFilters, tableInfo.getTableName(), tableInfo.getTableSchema()); + } catch (IOException e) { + throw new RuntimeException(e); + } this.filter = filterPushDownResult.f1; +// this.filterStr = filterPushDownRes.f1; LOG.info("Applied filters to native io: {}, accepted {}, remaining {}", this.filter, filterPushDownResult.f0.getAcceptedFilters(), filterPushDownResult.f0.getRemainingFilters()); +// LOG.info("FilterPlan: {}", this.filterPlan); return filterPushDownResult.f0; } @@ -215,6 +230,7 @@ public ScanRuntimeProvider getScanRuntimeProvider(ScanContext runtimeProviderCon this.pkColumns, this.optionParams, this.remainingPartitions, + this._filterPredicate, this.filter)); } diff --git a/native-io/lakesoul-io-java/pom.xml b/native-io/lakesoul-io-java/pom.xml index d4b048c3f..a8df9261e 100644 --- a/native-io/lakesoul-io-java/pom.xml +++ b/native-io/lakesoul-io-java/pom.xml @@ -26,6 +26,8 @@ SPDX-License-Identifier: Apache-2.0 8 12.0.0 3.1.0 + 0.28.0 + 3.22.0 @@ -84,6 +86,26 @@ SPDX-License-Identifier: Apache-2.0 2.2.16 + + io.substrait + 
core + + + org.slf4j + slf4j-jdk14 + + + ${substrait.version} + compile + + + + + com.google.protobuf + protobuf-java + ${protobuf.version} + + org.apache.spark spark-catalyst_${scala.binary.version} diff --git a/native-io/lakesoul-io-java/src/main/java/com/dmetasoul/lakesoul/lakesoul/io/NativeIOReader.java b/native-io/lakesoul-io-java/src/main/java/com/dmetasoul/lakesoul/lakesoul/io/NativeIOReader.java index e778e2528..6030a34a8 100644 --- a/native-io/lakesoul-io-java/src/main/java/com/dmetasoul/lakesoul/lakesoul/io/NativeIOReader.java +++ b/native-io/lakesoul-io-java/src/main/java/com/dmetasoul/lakesoul/lakesoul/io/NativeIOReader.java @@ -4,8 +4,9 @@ package com.dmetasoul.lakesoul.lakesoul.io; -import com.dmetasoul.lakesoul.lakesoul.io.jnr.LibLakeSoulIO; +import io.substrait.proto.Plan; import jnr.ffi.Pointer; +import jnr.ffi.Runtime; import jnr.ffi.byref.IntByReference; import org.apache.arrow.c.ArrowSchema; import org.apache.arrow.c.Data; @@ -34,6 +35,18 @@ public void addFilter(String filter) { ioConfigBuilder = libLakeSoulIO.lakesoul_config_builder_add_filter(ioConfigBuilder, filter); } + /** + * usually use only once + * + * @param plan Filter{} + */ + public void addFilterProto(Plan plan) { + byte[] bytes = plan.toByteArray(); + Pointer buf = Runtime.getRuntime(libLakeSoulIO).getMemoryManager().allocateDirect(bytes.length); + buf.put(0, bytes, 0, bytes.length); + ioConfigBuilder = libLakeSoulIO.lakesoul_config_builder_add_filter_proto(ioConfigBuilder, buf.address(), bytes.length); + } + public void addMergeOps(Map mergeOps) { for (Map.Entry entry : mergeOps.entrySet()) { ioConfigBuilder = libLakeSoulIO.lakesoul_config_builder_add_merge_op(ioConfigBuilder, entry.getKey(), entry.getValue()); diff --git a/native-io/lakesoul-io-java/src/main/java/com/dmetasoul/lakesoul/lakesoul/io/jnr/LibLakeSoulIO.java b/native-io/lakesoul-io-java/src/main/java/com/dmetasoul/lakesoul/lakesoul/io/jnr/LibLakeSoulIO.java index 2143679fb..894c39025 100644 --- a/native-io/lakesoul-io-java/src/main/java/com/dmetasoul/lakesoul/lakesoul/io/jnr/LibLakeSoulIO.java +++ b/native-io/lakesoul-io-java/src/main/java/com/dmetasoul/lakesoul/lakesoul/io/jnr/LibLakeSoulIO.java @@ -32,6 +32,8 @@ public interface LibLakeSoulIO { Pointer lakesoul_config_builder_add_filter(Pointer builder, String filter); + Pointer lakesoul_config_builder_add_filter_proto(Pointer builder, @LongLong long proto_addr, int len); + Pointer lakesoul_config_builder_add_merge_op(Pointer builder, String field, String mergeOp); Pointer lakesoul_config_builder_set_schema(Pointer builder, @LongLong long schemaAddr); diff --git a/native-io/lakesoul-io-java/src/main/java/com/dmetasoul/lakesoul/lakesoul/io/substrait/SubstraitUtil.java b/native-io/lakesoul-io-java/src/main/java/com/dmetasoul/lakesoul/lakesoul/io/substrait/SubstraitUtil.java new file mode 100644 index 000000000..1f8adb6f8 --- /dev/null +++ b/native-io/lakesoul-io-java/src/main/java/com/dmetasoul/lakesoul/lakesoul/io/substrait/SubstraitUtil.java @@ -0,0 +1,192 @@ +package com.dmetasoul.lakesoul.lakesoul.io.substrait; + +import io.substrait.dsl.SubstraitBuilder; +import io.substrait.expression.Expression; + +import io.substrait.expression.ExpressionCreator; +import io.substrait.expression.proto.ExpressionProtoConverter; +import io.substrait.extension.SimpleExtension; +import io.substrait.plan.Plan; +import io.substrait.plan.PlanProtoConverter; +import io.substrait.relation.NamedScan; +import io.substrait.type.Type; +import io.substrait.type.TypeCreator; +import 
org.apache.arrow.vector.types.DateUnit; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.IntervalUnit; +import org.apache.arrow.vector.types.TimeUnit; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class SubstraitUtil { + public static final SimpleExtension.ExtensionCollection Se; + public static final SubstraitBuilder Builder; + + public static final String CompNamespace = "/functions_comparison.yaml"; + public static final String BooleanNamespace = "/functions_boolean.yaml"; + + static { + try { + Se = SimpleExtension.loadDefaults(); + Builder = new SubstraitBuilder(Se); + } catch (IOException e) { + throw new RuntimeException("load simple extension failed"); + } + } + + + + public static Plan exprToFilter(Expression e, String tableName, Schema arrowSchema) { + if (e == null) { + return null; + } + List tableNames = Stream.of(tableName).collect(Collectors.toList()); + List columnNames = new ArrayList<>(); + List fields = arrowSchema.getFields(); + List columnTypes = new ArrayList<>(); + for (Field field : fields) { + Type type = fromArrowType(field.getType(), field.isNullable()); + if (type == null) { + return null; + } + columnTypes.add(type); + String name = field.getName(); + columnNames.add(name); + } + NamedScan namedScan = Builder.namedScan(tableNames, columnNames, columnTypes); + namedScan = + NamedScan.builder() + .from(namedScan) + .filter(e) + .build(); + + + Plan.Root root = Builder.root(namedScan); + return Builder.plan(root); + } + + + public static io.substrait.proto.Expression exprToProto(Expression expr) { + ExpressionProtoConverter converter = new ExpressionProtoConverter(null, null); + return expr.accept(converter); + } + + + public static io.substrait.proto.Plan planToProto(Plan plan) { + if (plan == null) { + return null; + } + return new PlanProtoConverter().toProto(plan); + } + + + + public static Type fromArrowType(ArrowType arrowType, boolean nullable) { + TypeCreator R = TypeCreator.of(nullable); + switch (arrowType.getTypeID()) { + case Null: + break; + case Struct: + break; + case List: + break; + case LargeList: + break; + case FixedSizeList: + break; + case Union: + break; + case Map: + break; + case Int: { + ArrowType.Int intType = (ArrowType.Int) arrowType; + if (intType.getIsSigned()) { + if (intType.getBitWidth() == 8) { + return R.I8; + } else if (intType.getBitWidth() == 16) { + return R.I16; + } else if (intType.getBitWidth() == 32) { + return R.I32; + } else if (intType.getBitWidth() == 64) { + return R.I64; + } + } + break; + } + case FloatingPoint: { + ArrowType.FloatingPoint fpType = (ArrowType.FloatingPoint) arrowType; + if (fpType.getPrecision() == FloatingPointPrecision.SINGLE) { + return R.FP32; + } else if (fpType.getPrecision() == FloatingPointPrecision.DOUBLE) { + return R.FP64; + } + break; + } + case Utf8: { + return R.STRING; + } + case LargeUtf8: + break; + case Binary: { + return R.BINARY; + } + case LargeBinary: + break; + case FixedSizeBinary: + break; + case Bool: { + return R.BOOLEAN; + } + case Decimal: { + ArrowType.Decimal decimalType = (ArrowType.Decimal) arrowType; + return R.decimal(decimalType.getPrecision(), decimalType.getScale()); + } + case Date: { + ArrowType.Date dateType = 
(ArrowType.Date) arrowType; + if (dateType.getUnit() == DateUnit.DAY) { + return R.DATE; + } + break; + } + case Time: + break; + case Timestamp: { + ArrowType.Timestamp tsType = (ArrowType.Timestamp) arrowType; + if (tsType.getUnit() == TimeUnit.MICROSECOND) { + if (tsType.getTimezone() != null) { + return R.TIMESTAMP_TZ; + } + return R.TIMESTAMP; + } + break; + } + case Interval: { + ArrowType.Interval intervalType = (ArrowType.Interval) arrowType; + if (intervalType.getUnit() == IntervalUnit.YEAR_MONTH) { + return R.INTERVAL_YEAR; + } + break; + } + case Duration: { + ArrowType.Duration durationType = (ArrowType.Duration) arrowType; + if (durationType.getUnit() == TimeUnit.MICROSECOND) { + return R.INTERVAL_DAY; + } + break; + } + case NONE: + break; + } + // not supported + return null; + } +} + diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 19dca0241..f405c9bb6 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -404,6 +404,17 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "async-recursion" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5fd55a5ba1179988837d24ab4c7cc8ed6efdeff578ede0416b4225a5fca35bd0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.52", +] + [[package]] name = "async-task" version = "4.7.0" @@ -1083,6 +1094,22 @@ dependencies = [ "sqlparser", ] +[[package]] +name = "datafusion-substrait" +version = "33.0.0" +source = "git+https://github.com/lakesoul-io/arrow-datafusion.git?branch=datafusion-33-parquet-prefetch#235eb27b6b0d23b18fb4a111fecbf5fa1b0d46a2" +dependencies = [ + "async-recursion", + "chrono", + "datafusion", + "itertools", + "object_store", + "prost", + "prost-types", + "substrait", + "tokio", +] + [[package]] name = "derivative" version = "2.2.0" @@ -1111,6 +1138,12 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" +[[package]] +name = "dyn-clone" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d6ef0072f8a535281e4876be788938b528e9a1d43900b82c2569af7da799125" + [[package]] name = "either" version = "1.9.0" @@ -1386,6 +1419,19 @@ version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" +[[package]] +name = "git2" +version = "0.18.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b3ba52851e73b46a4c3df1d89343741112003f0f6f13beb0dfac9e457c3fdcd" +dependencies = [ + "bitflags 2.4.1", + "libc", + "libgit2-sys", + "log", + "url", +] + [[package]] name = "glob" version = "0.3.1" @@ -1428,6 +1474,15 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" +[[package]] +name = "hashbrown" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43a3c133739dddd0d2990f9a4bdf8eb4b21ef50e4851ca85ab661199821d510e" +dependencies = [ + "ahash", +] + [[package]] name = "hashbrown" version = "0.14.3" @@ -1737,6 +1792,7 @@ dependencies = [ "dary_heap", "datafusion", "datafusion-common", + "datafusion-substrait", "derivative", "futures", "half", @@ -1746,6 +1802,7 @@ dependencies = [ "object_store", "parking_lot", "parquet", + "prost", "proto", "rand", "serde", @@ -1766,7 +1823,10 @@ version = "2.5.0" dependencies = [ "arrow", "cbindgen", + 
"datafusion-substrait", "lakesoul-io", + "log", + "prost", "serde", "serde_json", "tokio", @@ -1886,12 +1946,36 @@ version = "0.2.152" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "13e3bf6590cbc649f4d1a3eefc9d5d6eb746f5200ffb04e5e142700b8faa56e7" +[[package]] +name = "libgit2-sys" +version = "0.16.2+1.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee4126d8b4ee5c9d9ea891dd875cfdc1e9d0950437179104b183d7d8a74d24e8" +dependencies = [ + "cc", + "libc", + "libz-sys", + "pkg-config", +] + [[package]] name = "libm" version = "0.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" +[[package]] +name = "libz-sys" +version = "1.1.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "037731f5d3aaa87a5675e895b63ddff1a87624bc29f77004ea829809654e48f6" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + [[package]] name = "linux-raw-sys" version = "0.4.12" @@ -2620,6 +2704,16 @@ version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" +[[package]] +name = "regress" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ed9969cad8051328011596bf549629f1b800cf1731e7964b1eef8dfc480d2c2" +dependencies = [ + "hashbrown 0.13.2", + "memchr", +] + [[package]] name = "reqwest" version = "0.11.23" @@ -2771,6 +2865,30 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "schemars" +version = "0.8.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45a28f4c49489add4ce10783f7911893516f15afe45d015608d41faca6bc4d29" +dependencies = [ + "dyn-clone", + "schemars_derive", + "serde", + "serde_json", +] + +[[package]] +name = "schemars_derive" +version = "0.8.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c767fd6fa65d9ccf9cf026122c1b555f2ef9a4f0cea69da4d7dbc3e258d30967" +dependencies = [ + "proc-macro2", + "quote", + "serde_derive_internals", + "syn 1.0.109", +] + [[package]] name = "scopeguard" version = "1.2.0" @@ -2819,6 +2937,17 @@ dependencies = [ "syn 2.0.52", ] +[[package]] +name = "serde_derive_internals" +version = "0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85bf8229e7920a9f636479437026331ce11aa132b4dde37d121944a44d6e5f3c" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "serde_json" version = "1.0.114" @@ -2830,6 +2959,18 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_tokenstream" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a00ffd23fd882d096f09fcaae2a9de8329a328628e86027e049ee051dc1621f" +dependencies = [ + "proc-macro2", + "quote", + "serde", + "syn 2.0.52", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -2842,6 +2983,19 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_yaml" +version = "0.9.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fd075d994154d4a774f95b51fb96bdc2832b0ea48425c92546073816cda1f2f" +dependencies = [ + "indexmap 2.2.5", + "itoa", + "ryu", + "serde", + "unsafe-libyaml", +] + [[package]] name = "sha2" version = "0.10.8" @@ -3048,6 +3202,29 @@ dependencies = [ "syn 2.0.52", ] +[[package]] +name = "substrait" +version = "0.19.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "7299fc531294d189834eeaf7928482f311c0ada2cf0007948989cf75d0228183" +dependencies = [ + "git2", + "heck", + "prettyplease", + "prost", + "prost-build", + "prost-types", + "protobuf-src", + "schemars", + "semver", + "serde", + "serde_json", + "serde_yaml", + "syn 2.0.52", + "typify", + "walkdir", +] + [[package]] name = "subtle" version = "2.5.0" @@ -3419,6 +3596,50 @@ version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" +[[package]] +name = "typify" +version = "0.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2e3b707a653e2915a2fc2c4ee96a3d30b9554b9435eb4cc8b5c6c74bbdd3044" +dependencies = [ + "typify-impl", + "typify-macro", +] + +[[package]] +name = "typify-impl" +version = "0.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d9c752192779f666e4c868672dee56a652b82c08032c7e9d23f6a845b282298" +dependencies = [ + "heck", + "log", + "proc-macro2", + "quote", + "regress", + "schemars", + "serde_json", + "syn 2.0.52", + "thiserror", + "unicode-ident", +] + +[[package]] +name = "typify-macro" +version = "0.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a14defd554507e72a2bb93cd081c8b374cfed43b3d986b141ad3839d9fd6986b" +dependencies = [ + "proc-macro2", + "quote", + "schemars", + "serde", + "serde_json", + "serde_tokenstream", + "syn 2.0.52", + "typify-impl", +] + [[package]] name = "unicode-bidi" version = "0.3.14" @@ -3452,6 +3673,12 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e51733f11c9c4f72aa0c160008246859e340b00807569a0da0e7a1079b27ba85" +[[package]] +name = "unsafe-libyaml" +version = "0.2.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab4c90930b95a82d00dc9e9ac071b4991924390d46cbd0dfe566148667605e4b" + [[package]] name = "untrusted" version = "0.7.1" @@ -3509,6 +3736,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + [[package]] name = "version_check" version = "0.9.4" diff --git a/rust/justfile b/rust/justfile new file mode 100644 index 000000000..e94a306af --- /dev/null +++ b/rust/justfile @@ -0,0 +1,4 @@ +copy-to-java: + cargo build + cp target/debug/liblakesoul_io_c.dylib ../lakesoul-common/target/classes/ + cp target/debug/liblakesoul_metadata_c.dylib ../lakesoul-common/target/classes/ diff --git a/rust/lakesoul-datafusion/src/catalog/mod.rs b/rust/lakesoul-datafusion/src/catalog/mod.rs index 435cedce4..afdae5322 100644 --- a/rust/lakesoul-datafusion/src/catalog/mod.rs +++ b/rust/lakesoul-datafusion/src/catalog/mod.rs @@ -81,9 +81,7 @@ pub(crate) async fn create_io_config_builder( if let Some(table_name) = table_name { let table_info = client.get_table_info_by_table_name(table_name, namespace).await?; let data_files = if fetch_files { - client - .get_data_files_by_table_name(table_name, namespace) - .await? + client.get_data_files_by_table_name(table_name, namespace).await? 
} else { vec![] }; diff --git a/rust/lakesoul-datafusion/src/datasource/file_format/metadata_format.rs b/rust/lakesoul-datafusion/src/datasource/file_format/metadata_format.rs index 109050467..128418118 100644 --- a/rust/lakesoul-datafusion/src/datasource/file_format/metadata_format.rs +++ b/rust/lakesoul-datafusion/src/datasource/file_format/metadata_format.rs @@ -48,7 +48,10 @@ use tokio::task::JoinHandle; use tracing::debug; use crate::catalog::{commit_data, parse_table_info_partitions}; -use crate::lakesoul_table::helpers::{columnar_values_to_partition_desc, columnar_values_to_sub_path, create_io_config_builder_from_table_info, get_columnar_values}; +use crate::lakesoul_table::helpers::{ + columnar_values_to_partition_desc, columnar_values_to_sub_path, create_io_config_builder_from_table_info, + get_columnar_values, +}; pub struct LakeSoulMetaDataParquetFormat { parquet_format: Arc, @@ -64,7 +67,12 @@ impl Debug for LakeSoulMetaDataParquetFormat { } impl LakeSoulMetaDataParquetFormat { - pub async fn new(client: MetaDataClientRef, parquet_format: Arc, table_info: Arc, conf: LakeSoulIOConfig) -> crate::error::Result { + pub async fn new( + client: MetaDataClientRef, + parquet_format: Arc, + table_info: Arc, + conf: LakeSoulIOConfig, + ) -> crate::error::Result { Ok(Self { parquet_format, client, @@ -129,36 +137,49 @@ impl FileFormat for LakeSoulMetaDataParquetFormat { for field in &conf.table_partition_cols { builder.push(Field::new(field.name(), field.data_type().clone(), false)); } - + let table_schema = Arc::new(builder.finish()); - + let projection = conf.projection.clone(); let target_schema = project_schema(&table_schema, projection.as_ref())?; - let merged_projection = compute_project_column_indices(table_schema.clone(), target_schema.clone(), self.conf.primary_keys_slice()); + let merged_projection = compute_project_column_indices( + table_schema.clone(), + target_schema.clone(), + self.conf.primary_keys_slice(), + ); let merged_schema = project_schema(&table_schema, merged_projection.as_ref())?; // files to read - let flatten_conf = - flatten_file_scan_config(state, self.parquet_format.clone(), conf, self.conf.primary_keys_slice(), target_schema.clone()).await?; - + let flatten_conf = flatten_file_scan_config( + state, + self.parquet_format.clone(), + conf, + self.conf.primary_keys_slice(), + target_schema.clone(), + ) + .await?; - let mut inputs_map: HashMap>, Vec>) > = HashMap::new(); + let mut inputs_map: HashMap>, Vec>)> = + HashMap::new(); let mut column_nullable = HashSet::::new(); for config in &flatten_conf { - let (partition_desc, partition_columnar_value) = partition_desc_from_file_scan_config(&config)?; + let (partition_desc, partition_columnar_value) = partition_desc_from_file_scan_config(config)?; let partition_columnar_value = Arc::new(partition_columnar_value); - let parquet_exec = Arc::new(ParquetExec::new(config.clone(), predicate.clone(), self.parquet_format.metadata_size_hint(state.config_options()))); + let parquet_exec = Arc::new(ParquetExec::new( + config.clone(), + predicate.clone(), + self.parquet_format.metadata_size_hint(state.config_options()), + )); for field in parquet_exec.schema().fields().iter() { if field.is_nullable() { column_nullable.insert(field.name().clone()); } } - if let Some((_, inputs)) = inputs_map.get_mut(&partition_desc) - { + if let Some((_, inputs)) = inputs_map.get_mut(&partition_desc) { inputs.push(parquet_exec); } else { inputs_map.insert( @@ -168,21 +189,19 @@ impl FileFormat for LakeSoulMetaDataParquetFormat { } } - let 
merged_schema = SchemaRef::new( - Schema::new( - merged_schema - .fields() - .iter() - .map(|field| { - Field::new( - field.name(), - field.data_type().clone(), - field.is_nullable() | column_nullable.contains(field.name()) - ) - }) - .collect::>() - ) - ); + let merged_schema = SchemaRef::new(Schema::new( + merged_schema + .fields() + .iter() + .map(|field| { + Field::new( + field.name(), + field.data_type().clone(), + field.is_nullable() | column_nullable.contains(field.name()), + ) + }) + .collect::>(), + )); let mut partitioned_exec = Vec::new(); for (_, (partition_columnar_values, inputs)) in inputs_map { @@ -212,7 +231,6 @@ impl FileFormat for LakeSoulMetaDataParquetFormat { } else { Ok(exec) } - } async fn create_writer_physical_plan( @@ -273,7 +291,8 @@ impl LakeSoulHashSinkExec { table_info: Arc, metadata_client: MetaDataClientRef, ) -> Result { - let (range_partitions, _) = parse_table_info_partitions(table_info.partitions.clone()).map_err(|_| DataFusionError::External("parse table_info.partitions failed".into()))?; + let (range_partitions, _) = parse_table_info_partitions(table_info.partitions.clone()) + .map_err(|_| DataFusionError::External("parse table_info.partitions failed".into()))?; let range_partitions = Arc::new(range_partitions); Ok(Self { input, @@ -313,17 +332,16 @@ impl LakeSoulHashSinkExec { partitioned_file_path_and_row_count: Arc, u64)>>>, ) -> Result { let mut data = input.execute(partition, context.clone())?; - let schema_projection_excluding_range = - data.schema() - .fields() - .iter() - .enumerate() - .filter_map(|(idx, field)| - match range_partitions.contains(field.name()) { - true => None, - false => Some(idx) - }) - .collect::>(); + let schema_projection_excluding_range = data + .schema() + .fields() + .iter() + .enumerate() + .filter_map(|(idx, field)| match range_partitions.contains(field.name()) { + true => None, + false => Some(idx), + }) + .collect::>(); let mut row_count = 0; // let mut async_writer = MultiPartAsyncWriter::try_new(lakesoul_io_config).await?; @@ -334,7 +352,13 @@ impl LakeSoulHashSinkExec { let columnar_values = get_columnar_values(&batch, range_partitions.clone())?; let partition_desc = columnar_values_to_partition_desc(&columnar_values); let batch_excluding_range = batch.project(&schema_projection_excluding_range)?; - let file_absolute_path = format!("{}{}part-{}_{:0>4}.parquet", table_info.table_path, columnar_values_to_sub_path(&columnar_values), write_id, partition); + let file_absolute_path = format!( + "{}{}part-{}_{:0>4}.parquet", + table_info.table_path, + columnar_values_to_sub_path(&columnar_values), + write_id, + partition + ); if !partitioned_writer.contains_key(&partition_desc) { let mut config = create_io_config_builder_from_table_info(table_info.clone()) @@ -357,16 +381,12 @@ impl LakeSoulHashSinkExec { for (partition_desc, writer) in partitioned_writer.into_iter() { let file_absolute_path = writer.absolute_path(); let num_rows = writer.nun_rows(); - if let Some(file_path_and_row_count) = - partitioned_file_path_and_row_count_locked.get_mut(&partition_desc) - { + if let Some(file_path_and_row_count) = partitioned_file_path_and_row_count_locked.get_mut(&partition_desc) { file_path_and_row_count.0.push(file_absolute_path); file_path_and_row_count.1 += num_rows; } else { - partitioned_file_path_and_row_count_locked.insert( - partition_desc.clone(), - (vec![file_absolute_path], num_rows), - ); + partitioned_file_path_and_row_count_locked + .insert(partition_desc.clone(), (vec![file_absolute_path], num_rows)); } 
writer.flush_and_close().await?; } @@ -497,11 +517,7 @@ impl ExecutionPlan for LakeSoulHashSinkExec { let write_id = rand::distributions::Alphanumeric.sample_string(&mut rand::thread_rng(), 16); - let partitioned_file_path_and_row_count = - Arc::new( - Mutex::new( - HashMap::, u64)>::new() - )); + let partitioned_file_path_and_row_count = Arc::new(Mutex::new(HashMap::, u64)>::new())); for i in 0..num_input_partitions { let sink_task = tokio::spawn(Self::pull_and_sink( self.input().clone(), diff --git a/rust/lakesoul-datafusion/src/datasource/table_provider.rs b/rust/lakesoul-datafusion/src/datasource/table_provider.rs index 5456f37ce..a8afb30f0 100644 --- a/rust/lakesoul-datafusion/src/datasource/table_provider.rs +++ b/rust/lakesoul-datafusion/src/datasource/table_provider.rs @@ -31,7 +31,6 @@ use datafusion::{execution::context::SessionState, logical_expr::Expr}; use futures::stream::FuturesUnordered; use futures::StreamExt; - use lakesoul_io::helpers::listing_table_from_lakesoul_io_config; use lakesoul_io::lakesoul_io_config::LakeSoulIOConfig; use lakesoul_metadata::MetaDataClientRef; @@ -81,23 +80,28 @@ impl LakeSoulTableProvider { for (idx, field) in table_schema.fields().iter().enumerate() { match range_partitions.contains(field.name()) { false => file_schema_projection.push(idx), - true => range_partition_projection.push(idx) + true => range_partition_projection.push(idx), }; } - + let file_schema = Arc::new(table_schema.project(&file_schema_projection)?); - let table_schema = Arc::new(table_schema.project(&[file_schema_projection, range_partition_projection].concat())?); + let table_schema = + Arc::new(table_schema.project(&[file_schema_projection, range_partition_projection].concat())?); - let file_format: Arc = - Arc::new(LakeSoulMetaDataParquetFormat::new( + let file_format: Arc = Arc::new( + LakeSoulMetaDataParquetFormat::new( client.clone(), - Arc::new(ParquetFormat::new()), - table_info.clone(), - lakesoul_io_config.clone() - ).await?); + Arc::new(ParquetFormat::new()), + table_info.clone(), + lakesoul_io_config.clone(), + ) + .await?, + ); + + let (_, listing_table) = + listing_table_from_lakesoul_io_config(session_state, lakesoul_io_config.clone(), file_format, as_sink) + .await?; - let (_, listing_table) = listing_table_from_lakesoul_io_config(session_state, lakesoul_io_config.clone(), file_format, as_sink).await?; - Ok(Self { listing_table, client, @@ -133,12 +137,9 @@ impl LakeSoulTableProvider { &self.table_info.table_id } - fn is_partition_filter(&self, f: &Expr) -> bool { if let Ok(cols) = f.to_columns() { - cols - .iter() - .all(|col| self.range_partitions.contains(&col.name)) + cols.iter().all(|col| self.range_partitions.contains(&col.name)) } else { false } @@ -156,7 +157,7 @@ impl LakeSoulTableProvider { self.file_schema.clone() } - pub fn table_partition_cols(&self) -> &[(String, DataType)]{ + pub fn table_partition_cols(&self) -> &[(String, DataType)] { &self.options().table_partition_cols } @@ -180,15 +181,16 @@ impl LakeSoulTableProvider { }, }) } else { - return Err(DataFusionError::Plan( + Err(DataFusionError::Plan( // Return an error if schema of the input query does not match with the table schema. 
- format!("Expected single column references in output_ordering, got {}", expr) - )); + format!("Expected single column references in output_ordering, got {}", expr), + )) } } else { - return Err(DataFusionError::Plan( - format!("Expected Expr::Sort in output_ordering, but got {}", expr) - )); + Err(DataFusionError::Plan(format!( + "Expected Expr::Sort in output_ordering, but got {}", + expr + ))) } }) .collect::>>()?; @@ -196,8 +198,6 @@ impl LakeSoulTableProvider { } Ok(all_sort_orders) } - - async fn list_files_for_scan<'a>( &'a self, @@ -211,16 +211,28 @@ impl LakeSoulTableProvider { return Ok((vec![], Statistics::new_unknown(&self.file_schema()))); }; - let all_partition_info = self.client - .get_all_partition_info(self.table_id()) + let all_partition_info = self.client.get_all_partition_info(self.table_id()).await.map_err(|_| { + DataFusionError::External( + format!( + "get all partition_info of table {} failed", + &self.table_info().table_name + ) + .into(), + ) + })?; + + let prune_partition_info = prune_partitions(all_partition_info, filters, self.table_partition_cols()) .await - .map_err(|_| DataFusionError::External(format!("get all partition_info of table {} failed", &self.table_info().table_name).into()))?; + .map_err(|_| { + DataFusionError::External( + format!( + "get all partition_info of table {} failed", + &self.table_info().table_name + ) + .into(), + ) + })?; - let prune_partition_info = - prune_partitions(all_partition_info, filters, self.table_partition_cols()) - .await - .map_err(|_| DataFusionError::External(format!("get all partition_info of table {} failed", &self.table_info().table_name).into()))?; - let mut futures = FuturesUnordered::new(); for partition in prune_partition_info { futures.push(listing_partition_info(partition, store.as_ref(), self.client())) @@ -236,28 +248,23 @@ impl LakeSoulTableProvider { .into_iter() .flatten() .zip(self.table_partition_cols()) - .map(|(parsed, (_, datatype))| { - ScalarValue::try_from_string(parsed.to_string(), datatype) - }) + .map(|(parsed, (_, datatype))| ScalarValue::try_from_string(parsed.to_string(), datatype)) .collect::>>()?; let files = object_metas .into_iter() - .map(|object_meta| - PartitionedFile { - object_meta, - partition_values: partition_values.clone(), - range: None, - extensions: None, - } - ) + .map(|object_meta| PartitionedFile { + object_meta, + partition_values: partition_values.clone(), + range: None, + extensions: None, + }) .collect::>(); file_groups.push(files) } Ok((file_groups, Statistics::new_unknown(self.schema().deref()))) } - } #[async_trait] @@ -281,8 +288,7 @@ impl TableProvider for LakeSoulTableProvider { filters: &[Expr], limit: Option, ) -> Result> { - let (partitioned_file_lists, _) = - self.list_files_for_scan(state, filters, limit).await?; + let (partitioned_file_lists, _) = self.list_files_for_scan(state, filters, limit).await?; // if no files need to be read, return an `EmptyExec` if partitioned_file_lists.is_empty() { @@ -299,22 +305,17 @@ impl TableProvider for LakeSoulTableProvider { .iter() .map(|col| Ok(self.schema().field_with_name(&col.0)?.clone())) .collect::>>()?; - + let filters = if let Some(expr) = conjunction(filters.to_vec()) { // NOTE: Use the table schema (NOT file schema) here because `expr` may contain references to partition columns. 
let table_df_schema = self.schema().as_ref().clone().to_dfschema()?; - let filters = create_physical_expr( - &expr, - &table_df_schema, - &self.schema(), - state.execution_props(), - )?; + let filters = create_physical_expr(&expr, &table_df_schema, &self.schema(), state.execution_props())?; Some(filters) } else { None }; - let object_store_url = if let Some(url) = self.listing_table.table_paths().get(0) { + let object_store_url = if let Some(url) = self.listing_table.table_paths().first() { url.object_store() } else { return Ok(Arc::new(EmptyExec::new(false, Arc::new(Schema::empty())))); @@ -340,7 +341,6 @@ impl TableProvider for LakeSoulTableProvider { filters.as_ref(), ) .await - } fn supports_filters_pushdown(&self, filters: &[&Expr]) -> Result> { @@ -354,7 +354,6 @@ impl TableProvider for LakeSoulTableProvider { } }) .collect() - } async fn insert_into( @@ -363,7 +362,6 @@ impl TableProvider for LakeSoulTableProvider { input: Arc, overwrite: bool, ) -> Result> { - let table_path = &self.listing_table.table_paths()[0]; // Get the object store for the table path. let _store = state.runtime_env().object_store(table_path)?; diff --git a/rust/lakesoul-datafusion/src/lakesoul_table/helpers.rs b/rust/lakesoul-datafusion/src/lakesoul_table/helpers.rs index 2e049b189..b1004e374 100644 --- a/rust/lakesoul-datafusion/src/lakesoul_table/helpers.rs +++ b/rust/lakesoul-datafusion/src/lakesoul_table/helpers.rs @@ -4,11 +4,23 @@ use std::sync::Arc; -use arrow::{array::{Array, ArrayRef, AsArray, StringBuilder}, compute::prep_null_mask_filter, datatypes::{DataType, Field, Fields, Schema}, record_batch::RecordBatch}; -use arrow_cast::cast; +use arrow::{ + array::{Array, ArrayRef, AsArray, StringBuilder}, + compute::prep_null_mask_filter, + datatypes::{DataType, Field, Fields, Schema}, + record_batch::RecordBatch, +}; use arrow_arith::boolean::and; +use arrow_cast::cast; -use datafusion::{common::{DFField, DFSchema}, error::DataFusionError, execution::context::ExecutionProps, logical_expr::Expr, physical_expr::create_physical_expr, scalar::ScalarValue}; +use datafusion::{ + common::{DFField, DFSchema}, + error::DataFusionError, + execution::context::ExecutionProps, + logical_expr::Expr, + physical_expr::create_physical_expr, + scalar::ScalarValue, +}; use lakesoul_metadata::MetaDataClientRef; use object_store::{path::Path, ObjectMeta, ObjectStore}; use tracing::{debug, trace}; @@ -34,36 +46,41 @@ pub(crate) fn create_io_config_builder_from_table_info(table_info: Arc>) -> datafusion::error::Result> { +pub fn get_columnar_values( + batch: &RecordBatch, + range_partitions: Arc>, +) -> datafusion::error::Result> { range_partitions .iter() .map(|range_col| { - if let Some(array) = batch.column_by_name(&range_col) { + if let Some(array) = batch.column_by_name(range_col) { match ScalarValue::try_from_array(array, 0) { Ok(scalar) => Ok((range_col.clone(), scalar)), - Err(e) => Err(e) + Err(e) => Err(e), } } else { - Err(datafusion::error::DataFusionError::External(format!("").into())) + Err(datafusion::error::DataFusionError::External(String::new().into())) } }) .collect::>>() } -pub fn columnar_values_to_sub_path(columnar_values: &Vec<(String, ScalarValue)>) -> String { +pub fn columnar_values_to_sub_path(columnar_values: &[(String, ScalarValue)]) -> String { if columnar_values.is_empty() { "/".to_string() } else { - format!("/{}/", columnar_values - .iter() - .map(|(k, v)| format!("{}={}", k, v)) - .collect::>() - .join("/")) + format!( + "/{}/", + columnar_values + .iter() + .map(|(k, v)| format!("{}={}", k, 
v)) + .collect::>() + .join("/") + ) } } -pub fn columnar_values_to_partition_desc(columnar_values: &Vec<(String, ScalarValue)>) -> String { +pub fn columnar_values_to_partition_desc(columnar_values: &[(String, ScalarValue)]) -> String { if columnar_values.is_empty() { "-5".to_string() } else { @@ -130,20 +147,15 @@ pub async fn prune_partitions( // Applies `filter` to `batch` returning `None` on error let do_filter = |filter| -> Option { let expr = create_physical_expr(filter, &df_schema, &schema, &props).ok()?; - expr.evaluate(&batch) - .ok()? - .into_array(all_partition_info.len()) - .ok() + expr.evaluate(&batch).ok()?.into_array(all_partition_info.len()).ok() }; //.Compute the conjunction of the filters, ignoring errors - let mask = filters - .iter() - .fold(None, |acc, filter| match (acc, do_filter(filter)) { - (Some(a), Some(b)) => Some(and(&a, b.as_boolean()).unwrap_or(a)), - (None, Some(r)) => Some(r.as_boolean().clone()), - (r, None) => r, - }); + let mask = filters.iter().fold(None, |acc, filter| match (acc, do_filter(filter)) { + (Some(a), Some(b)) => Some(and(&a, b.as_boolean()).unwrap_or(a)), + (None, Some(r)) => Some(r.as_boolean().clone()), + (r, None) => r, + }); let mask = match mask { Some(mask) => mask, @@ -155,7 +167,7 @@ pub async fn prune_partitions( 0 => mask, _ => prep_null_mask_filter(&mask), }; - + // Sanity check assert_eq!(prepared.len(), all_partition_info.len()); @@ -164,45 +176,53 @@ pub async fn prune_partitions( .zip(prepared.values()) .filter_map(|(p, f)| f.then_some(p)) .collect(); - + Ok(filtered) } pub fn parse_partitions_for_partition_desc<'a, I>( partition_desc: &'a str, table_partition_cols: I, -) -> Option> +) -> Option> where I: IntoIterator, { let mut part_values = vec![]; - for (part, pn) in partition_desc.split(",").zip(table_partition_cols) { + for (part, pn) in partition_desc.split(',').zip(table_partition_cols) { match part.split_once('=') { Some((name, val)) if name == pn => part_values.push(val), _ => { debug!( "Ignoring file: partition_desc='{}', part='{}', partition_col='{}'", - partition_desc, - part, - pn, + partition_desc, part, pn, ); return None; } } } Some(part_values) - } - -pub async fn listing_partition_info(partition_info: PartitionInfo, store: &dyn ObjectStore, client: MetaDataClientRef) -> datafusion::error::Result<(PartitionInfo, Vec)> { +pub async fn listing_partition_info( + partition_info: PartitionInfo, + store: &dyn ObjectStore, + client: MetaDataClientRef, +) -> datafusion::error::Result<(PartitionInfo, Vec)> { trace!("Listing partition {:?}", partition_info); let paths = client - .get_data_files_of_single_partition(&partition_info).await.map_err(|_| DataFusionError::External("listing partition info failed".into()))?; + .get_data_files_of_single_partition(&partition_info) + .await + .map_err(|_| DataFusionError::External("listing partition info failed".into()))?; let mut files = Vec::new(); for path in paths { - let result = store.head(&Path::from_url_path(Url::parse(path.as_str()).map_err(|e| DataFusionError::External(Box::new(e)))?.path())?).await?; + let result = store + .head(&Path::from_url_path( + Url::parse(path.as_str()) + .map_err(|e| DataFusionError::External(Box::new(e)))? + .path(), + )?) 
+ .await?; files.push(result); } Ok((partition_info, files)) -} \ No newline at end of file +} diff --git a/rust/lakesoul-datafusion/src/lakesoul_table/mod.rs b/rust/lakesoul-datafusion/src/lakesoul_table/mod.rs index 277e19d2e..8a8a1328a 100644 --- a/rust/lakesoul-datafusion/src/lakesoul_table/mod.rs +++ b/rust/lakesoul-datafusion/src/lakesoul_table/mod.rs @@ -6,7 +6,7 @@ pub mod helpers; use std::{ops::Deref, sync::Arc}; -use arrow::datatypes::{SchemaRef, Schema}; +use arrow::datatypes::{Schema, SchemaRef}; use arrow_cast::pretty::pretty_format_batches; use datafusion::sql::TableReference; use datafusion::{ @@ -131,7 +131,14 @@ impl LakeSoulTable { let config_builder = create_io_config_builder(self.client(), Some(self.table_name()), true, self.table_namespace()).await?; let provider = Arc::new( - LakeSoulTableProvider::try_new(&context.state(), self.client(), config_builder.build(), self.table_info(), false).await?, + LakeSoulTableProvider::try_new( + &context.state(), + self.client(), + config_builder.build(), + self.table_info(), + false, + ) + .await?, ); Ok(context.read_table(provider)?) } @@ -142,7 +149,14 @@ impl LakeSoulTable { .await? .with_prefix(self.table_info.table_path.clone()); Ok(Arc::new( - LakeSoulTableProvider::try_new(session_state, self.client(), config_builder.build(), self.table_info(), true).await?, + LakeSoulTableProvider::try_new( + session_state, + self.client(), + config_builder.build(), + self.table_info(), + true, + ) + .await?, )) } diff --git a/rust/lakesoul-datafusion/src/planner/physical_planner.rs b/rust/lakesoul-datafusion/src/planner/physical_planner.rs index a5a830574..5ad1568f4 100644 --- a/rust/lakesoul-datafusion/src/planner/physical_planner.rs +++ b/rust/lakesoul-datafusion/src/planner/physical_planner.rs @@ -19,7 +19,7 @@ use datafusion::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner}; use async_trait::async_trait; use datafusion::logical_expr::{DmlStatement, WriteOp}; -use lakesoul_io::helpers::{column_names_to_physical_sort_expr, column_names_to_physical_expr}; +use lakesoul_io::helpers::{column_names_to_physical_expr, column_names_to_physical_sort_expr}; use lakesoul_io::repartition::RepartitionByRangeAndHashExec; use crate::lakesoul_table::LakeSoulTable; @@ -62,25 +62,28 @@ impl PhysicalPlanner for LakeSoulPhysicalPlanner { Ok(provider) => { let physical_input = self.create_physical_plan(input, session_state).await?; - if lakesoul_table.primary_keys().is_empty() { - if !lakesoul_table + if lakesoul_table.primary_keys().is_empty() + && !lakesoul_table .schema() .logically_equivalent_names_and_types(&Schema::from(input.schema().as_ref())) - { - return Err(DataFusionError::Plan( - // Return an error if schema of the input query does not match with the table schema. - "Inserting query must have the same schema with the table.".to_string(), - )); - } - } - let physical_input = if !lakesoul_table.primary_keys().is_empty() || !lakesoul_table.range_partitions().is_empty() { + { + return Err(DataFusionError::Plan( + // Return an error if schema of the input query does not match with the table schema. 
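// (Note: as the combined condition above shows, this schema-equivalence check is only applied to tables without primary keys.)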
+ "Inserting query must have the same schema with the table.".to_string(), + )); + } + let physical_input = if !lakesoul_table.primary_keys().is_empty() + || !lakesoul_table.range_partitions().is_empty() + { let input_schema = physical_input.schema(); let input_dfschema = input.as_ref().schema(); let sort_expr = column_names_to_physical_sort_expr( [ - lakesoul_table.range_partitions().clone(), + lakesoul_table.range_partitions().clone(), lakesoul_table.primary_keys().clone(), - ].concat().as_slice(), + ] + .concat() + .as_slice(), input_dfschema, &input_schema, session_state, @@ -91,8 +94,9 @@ impl PhysicalPlanner for LakeSoulPhysicalPlanner { &input_schema, session_state, )?; - - let hash_partitioning = Partitioning::Hash(hash_partitioning_expr, lakesoul_table.hash_bucket_num()); + + let hash_partitioning = + Partitioning::Hash(hash_partitioning_expr, lakesoul_table.hash_bucket_num()); let range_partitioning_expr = column_names_to_physical_expr( lakesoul_table.range_partitions(), input_dfschema, @@ -100,7 +104,11 @@ impl PhysicalPlanner for LakeSoulPhysicalPlanner { session_state, )?; let sort_exec = Arc::new(SortExec::new(sort_expr, physical_input)); - Arc::new(RepartitionByRangeAndHashExec::try_new(sort_exec, range_partitioning_expr, hash_partitioning)?) + Arc::new(RepartitionByRangeAndHashExec::try_new( + sort_exec, + range_partitioning_expr, + hash_partitioning, + )?) } else { physical_input }; diff --git a/rust/lakesoul-datafusion/src/test/benchmarks/tpch/mod.rs b/rust/lakesoul-datafusion/src/test/benchmarks/tpch/mod.rs index 6c26ced7d..31e2dafc5 100644 --- a/rust/lakesoul-datafusion/src/test/benchmarks/tpch/mod.rs +++ b/rust/lakesoul-datafusion/src/test/benchmarks/tpch/mod.rs @@ -3,56 +3,29 @@ // SPDX-License-Identifier: Apache-2.0 mod run; -use arrow::datatypes::{Schema, SchemaBuilder, Field, DataType}; +use arrow::datatypes::{DataType, Field, Schema, SchemaBuilder}; pub const TPCH_TABLES: &[&str] = &[ - "part", "supplier", "partsupp", "customer", - "orders", - "lineitem", - "nation", "region", + "part", "supplier", "partsupp", "customer", "orders", "lineitem", "nation", "region", ]; pub fn get_tbl_tpch_table_primary_keys(table: &str) -> Vec { match table { - "part" => vec![ - String::from("p_partkey"), - String::from("p_name"), - ], + "part" => vec![String::from("p_partkey"), String::from("p_name")], - "supplier" => vec![ - String::from("s_suppkey"), - String::from("s_name"), - ], + "supplier" => vec![String::from("s_suppkey"), String::from("s_name")], - "partsupp" => vec![ - String::from("ps_partkey"), - String::from("ps_suppkey"), - ], + "partsupp" => vec![String::from("ps_partkey"), String::from("ps_suppkey")], - "customer" => vec![ - String::from("c_custkey"), - String::from("c_name"), - ], + "customer" => vec![String::from("c_custkey"), String::from("c_name")], - "orders" => vec![ - String::from("o_orderkey"), - String::from("o_custkey"), - ], + "orders" => vec![String::from("o_orderkey"), String::from("o_custkey")], - "lineitem" => vec![ - String::from("l_orderkey"), - String::from("l_partkey"), - ], + "lineitem" => vec![String::from("l_orderkey"), String::from("l_partkey")], - "nation" => vec![ - String::from("n_nationkey"), - String::from("n_name"), - ], + "nation" => vec![String::from("n_nationkey"), String::from("n_name")], - "region" => vec![ - String::from("r_regionkey"), - String::from("r_name"), - ], + "region" => vec![String::from("r_regionkey"), String::from("r_name")], _ => unimplemented!(), } @@ -62,37 +35,27 @@ pub fn get_tbl_tpch_table_range_partitions(table: 
&str) -> Vec { match table { "part" => vec![], - "supplier" => vec![ - String::from("s_nationkey"), - ], + "supplier" => vec![String::from("s_nationkey")], "partsupp" => vec![], - "customer" => vec![ - String::from("c_nationkey"), - ], + "customer" => vec![String::from("c_nationkey")], "orders" => vec![ // String::from("o_orderdate"), String::from("o_orderpriority"), ], - "lineitem" => vec![ - ], + "lineitem" => vec![], - "nation" => vec![ - String::from("n_regionkey"), - ], + "nation" => vec![String::from("n_regionkey")], - "region" => vec![ - ], + "region" => vec![], _ => unimplemented!(), } } - - /// The `.tbl` file contains a trailing column pub fn get_tbl_tpch_table_schema(table: &str) -> Schema { let mut schema = SchemaBuilder::from(get_tpch_table_schema(table).fields); @@ -195,5 +158,3 @@ pub fn get_tpch_table_schema(table: &str) -> Schema { _ => unimplemented!(), } } - - diff --git a/rust/lakesoul-datafusion/src/test/catalog_tests.rs b/rust/lakesoul-datafusion/src/test/catalog_tests.rs index 2b4517c72..1f863d780 100644 --- a/rust/lakesoul-datafusion/src/test/catalog_tests.rs +++ b/rust/lakesoul-datafusion/src/test/catalog_tests.rs @@ -49,7 +49,7 @@ mod catalog_tests { namespace: { let mut v = String::with_capacity(5); for _ in 0..10 { - v.push((&mut rng).gen_range('a'..'z')); + v.push(rng.gen_range('a'..'z')); } format!("{prefix}_{v}") }, @@ -74,7 +74,7 @@ mod catalog_tests { let table_name = { let mut v = String::with_capacity(8); for _ in 0..10 { - v.push((&mut rng).gen_range('a'..'z')); + v.push(rng.gen_range('a'..'z')); } v }; @@ -264,7 +264,7 @@ mod catalog_tests { let q = format!("show columns from test_catalog_sql.{}.{}", np.namespace, name); let df = sc.sql(&q).await.unwrap(); let record = df.collect().await.unwrap(); - assert!(record.len() > 0); + assert!(!record.is_empty()); } { // test select diff --git a/rust/lakesoul-datafusion/src/test/integration_tests.rs b/rust/lakesoul-datafusion/src/test/integration_tests.rs index 31d4138ed..0bbe076a0 100644 --- a/rust/lakesoul-datafusion/src/test/integration_tests.rs +++ b/rust/lakesoul-datafusion/src/test/integration_tests.rs @@ -5,26 +5,36 @@ mod integration_tests { use std::{path::Path, sync::Arc}; - use datafusion::{execution::context::SessionContext, datasource::{TableProvider, file_format::{FileFormat, csv::CsvFormat}, listing::{ListingOptions, ListingTableUrl, ListingTableConfig, ListingTable}}}; + use datafusion::{ + datasource::{ + file_format::{csv::CsvFormat, FileFormat}, + listing::{ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl}, + TableProvider, + }, + execution::context::SessionContext, + }; use lakesoul_io::lakesoul_io_config::{create_session_context_with_planner, LakeSoulIOConfigBuilder}; use lakesoul_metadata::MetaDataClient; - use crate::{catalog::{create_io_config_builder, create_table}, error::{LakeSoulError, Result}, lakesoul_table::LakeSoulTable, planner::query_planner::LakeSoulQueryPlanner, test::benchmarks::tpch::get_tbl_tpch_table_range_partitions}; - use crate::test::benchmarks::tpch::{TPCH_TABLES, get_tbl_tpch_table_schema, get_tpch_table_schema, get_tbl_tpch_table_primary_keys}; - - async fn get_table( - ctx: &SessionContext, - table: &str, - ) -> Result> { + use crate::test::benchmarks::tpch::{ + get_tbl_tpch_table_primary_keys, get_tbl_tpch_table_schema, get_tpch_table_schema, TPCH_TABLES, + }; + use crate::{ + catalog::{create_io_config_builder, create_table}, + error::{LakeSoulError, Result}, + lakesoul_table::LakeSoulTable, + planner::query_planner::LakeSoulQueryPlanner, + 
test::benchmarks::tpch::get_tbl_tpch_table_range_partitions, + }; + + async fn get_table(ctx: &SessionContext, table: &str) -> Result> { let path = get_tpch_data_path()?; // Obtain a snapshot of the SessionState let state = ctx.state(); let (format, path, extension): (Arc, String, &'static str) = { let path = format!("{path}/{table}.tbl"); - let format = CsvFormat::default() - .with_delimiter(b'|') - .with_has_header(false); + let format = CsvFormat::default().with_delimiter(b'|').with_has_header(false); (Arc::new(format), path, ".tbl") }; @@ -36,16 +46,12 @@ mod integration_tests { let table_path = ListingTableUrl::parse(path)?; let config = ListingTableConfig::new(table_path).with_listing_options(options); - let config = { - config.with_schema(Arc::new(get_tbl_tpch_table_schema(table))) - }; + let config = { config.with_schema(Arc::new(get_tbl_tpch_table_schema(table))) }; Ok(Arc::new(ListingTable::try_new(config)?)) - } - + fn get_tpch_data_path() -> Result { - let path = - std::env::var("TPCH_DATA").unwrap_or_else(|_| "benchmarks/data".to_string()); + let path = std::env::var("TPCH_DATA").unwrap_or_else(|_| "benchmarks/data".to_string()); if !Path::new(&path).exists() { return Err(LakeSoulError::Internal(format!( "Benchmark data not found (set TPCH_DATA env var to override): {}", @@ -55,19 +61,17 @@ mod integration_tests { Ok(path) } - #[tokio::test] async fn load_tpch_data() -> Result<()> { let client = Arc::new(MetaDataClient::from_env().await?); let builder = create_io_config_builder(client.clone(), None, false, "default").await?; let ctx = create_session_context_with_planner(&mut builder.clone().build(), Some(LakeSoulQueryPlanner::new_ref()))?; - + for table in TPCH_TABLES { let table_provider = get_table(&ctx, table).await?; ctx.register_table(*table, table_provider)?; - let dataframe = ctx.sql(format!("select * from {}", table).as_str()) - .await?; + let dataframe = ctx.sql(format!("select * from {}", table).as_str()).await?; let schema = get_tpch_table_schema(table); @@ -85,4 +89,4 @@ mod integration_tests { Ok(()) } -} \ No newline at end of file +} diff --git a/rust/lakesoul-datafusion/src/test/mod.rs b/rust/lakesoul-datafusion/src/test/mod.rs index d7d9b6b2c..72576dee8 100644 --- a/rust/lakesoul-datafusion/src/test/mod.rs +++ b/rust/lakesoul-datafusion/src/test/mod.rs @@ -2,8 +2,8 @@ // // SPDX-License-Identifier: Apache-2.0 -use std::sync::Arc; use arrow::array::RecordBatch; +use std::sync::Arc; use tracing::debug; use lakesoul_metadata::MetaDataClient; @@ -38,31 +38,30 @@ fn init() { fn assert_batches_eq(table_name: &str, expected: &[&str], results: &[RecordBatch]) { // let expected_lines: Vec = // expected.iter().map(|&s| s.into()).collect(); - let (schema, remain)= expected.split_at(3); + let (schema, remain) = expected.split_at(3); let (expected, end) = remain.split_at(remain.len() - 1); let mut expected = Vec::from(expected); - + expected.sort(); - + let expected_lines = [schema, &expected, end].concat(); - let formatted = datafusion::arrow::util::pretty::pretty_format_batches(results) .unwrap() .to_string(); let actual_lines: Vec<&str> = formatted.trim().lines().collect(); - let (schema, remain)= actual_lines.split_at(3); + let (schema, remain) = actual_lines.split_at(3); let (result, end) = remain.split_at(remain.len() - 1); let mut result = Vec::from(result); - + result.sort(); - + let result = [schema, &result, end].concat(); assert_eq!( - expected_lines, result, + expected_lines, result, "\n\n{}\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", table_name, 
expected_lines, result ); -} \ No newline at end of file +} diff --git a/rust/lakesoul-datafusion/src/test/upsert_tests.rs b/rust/lakesoul-datafusion/src/test/upsert_tests.rs index d68a43b06..5f91239bb 100644 --- a/rust/lakesoul-datafusion/src/test/upsert_tests.rs +++ b/rust/lakesoul-datafusion/src/test/upsert_tests.rs @@ -963,15 +963,7 @@ mod upsert_with_io_config_tests { None, builder.clone(), &[ - "+-----+", - "| age |", - "+-----+", - "| 1 |", - "| 2 |", - "| |", - "| |", - "| |", - "| |", + "+-----+", "| age |", "+-----+", "| 1 |", "| 2 |", "| |", "| |", "| |", "| |", "+-----+", ], ); @@ -2297,14 +2289,8 @@ mod upsert_with_metadata_tests { .map(|name| Field::new(name, DataType::Int32, true)) .collect::>(), )), - vec![ - "hash1".to_string(), - "hash2".to_string(), - ], - vec![ - "range1".to_string(), - "range2".to_string(), - ], + vec!["hash1".to_string(), "hash2".to_string()], + vec!["range1".to_string(), "range2".to_string()], client.clone(), ) .await?; @@ -2644,15 +2630,7 @@ mod upsert_with_metadata_tests { None, client.clone(), &[ - "+-----+", - "| age |", - "+-----+", - "| 1 |", - "| 2 |", - "| |", - "| |", - "| |", - "| |", + "+-----+", "| age |", "+-----+", "| 1 |", "| 2 |", "| |", "| |", "| |", "| |", "+-----+", ], ) diff --git a/rust/lakesoul-io-c/Cargo.toml b/rust/lakesoul-io-c/Cargo.toml index 00fafa27f..2e8b5808f 100644 --- a/rust/lakesoul-io-c/Cargo.toml +++ b/rust/lakesoul-io-c/Cargo.toml @@ -16,6 +16,16 @@ arrow = { workspace = true, features = ["ffi"] } tokio = { version = "1", features = ["full"] } serde_json = "1.0" serde = { version = "1.0", default-features = false, features = ["derive", "std"], optional = true } +prost = "0.12.3" +log = "0.4.20" + +[target.'cfg(target_os = "windows")'.dependencies] +datafusion-substrait = { workspace = true } + +[target.'cfg(not(target_os = "windows"))'.dependencies] +datafusion-substrait = { workspace = true, features = ["protoc"] } + + [features] hdfs = ["lakesoul-io/hdfs"] diff --git a/rust/lakesoul-io-c/lakesoul_c_bindings.h b/rust/lakesoul-io-c/lakesoul_c_bindings.h index 77f12ef19..9528a4afc 100644 --- a/rust/lakesoul-io-c/lakesoul_c_bindings.h +++ b/rust/lakesoul-io-c/lakesoul_c_bindings.h @@ -68,6 +68,10 @@ IOConfigBuilder *lakesoul_config_builder_add_single_aux_sort_column(IOConfigBuil IOConfigBuilder *lakesoul_config_builder_add_filter(IOConfigBuilder *builder, const char *filter); +IOConfigBuilder *lakesoul_config_builder_add_filter_proto(IOConfigBuilder *builder, + c_ptrdiff_t proto_addr, + int32_t len); + IOConfigBuilder *lakesoul_config_builder_set_schema(IOConfigBuilder *builder, c_ptrdiff_t schema_addr); diff --git a/rust/lakesoul-io-c/src/lib.rs b/rust/lakesoul-io-c/src/lib.rs index 0d707ef92..d69af2a11 100644 --- a/rust/lakesoul-io-c/src/lib.rs +++ b/rust/lakesoul-io-c/src/lib.rs @@ -17,11 +17,14 @@ pub use arrow::array::StructArray; use arrow::datatypes::Schema; use arrow::ffi::from_ffi; pub use arrow::ffi::{FFI_ArrowArray, FFI_ArrowSchema}; +use datafusion_substrait::substrait::proto::Plan; +use prost::Message; use tokio::runtime::{Builder, Runtime}; use lakesoul_io::lakesoul_io_config::{LakeSoulIOConfig, LakeSoulIOConfigBuilder}; use lakesoul_io::lakesoul_reader::{LakeSoulReader, RecordBatch, Result, SyncSendableMutableLakeSoulReader}; use lakesoul_io::lakesoul_writer::SyncSendableMutableLakeSoulWriter; +use log::debug; #[repr(C)] pub struct CResult { @@ -148,6 +151,21 @@ pub extern "C" fn lakesoul_config_builder_add_filter( } } +#[no_mangle] +pub extern "C" fn lakesoul_config_builder_add_filter_proto( + 
builder: NonNull, + proto_addr: c_ptrdiff_t, + len: i32, +) -> NonNull { + unsafe { + debug!("proto_addr: {:#x}, len:{}", proto_addr, len); + let dst: &mut [u8] = slice::from_raw_parts_mut(proto_addr as *mut u8, len as usize); + let plan = Plan::decode(&*dst).unwrap(); + debug!("{:#?}", plan); + convert_to_opaque(from_opaque::(builder).with_filter_proto(plan)) + } +} + #[no_mangle] pub extern "C" fn lakesoul_config_builder_set_schema( builder: NonNull, diff --git a/rust/lakesoul-io/Cargo.toml b/rust/lakesoul-io/Cargo.toml index 865e767ae..de653130b 100644 --- a/rust/lakesoul-io/Cargo.toml +++ b/rust/lakesoul-io/Cargo.toml @@ -40,6 +40,7 @@ parking_lot = "0.12.1" half = { workspace = true } log = "0.4.20" anyhow = { workspace = true, features = [] } +prost = "0.12.3" [features] @@ -47,6 +48,15 @@ hdfs = ["dep:hdrs"] simd = ["datafusion/simd", "arrow/simd", "arrow-array/simd"] default = [] +[target.'cfg(target_os = "windows")'.dependencies] +datafusion-substrait = { workspace = true } + +[target.'cfg(not(target_os = "windows"))'.dependencies] +datafusion-substrait = { workspace = true, features = ["protoc"] } + + + + [dev-dependencies] tempfile = "3.3.0" comfy-table = "6.0" diff --git a/rust/lakesoul-io/src/datasource/file_format.rs b/rust/lakesoul-io/src/datasource/file_format.rs index 662ba3181..3f40cacb9 100644 --- a/rust/lakesoul-io/src/datasource/file_format.rs +++ b/rust/lakesoul-io/src/datasource/file_format.rs @@ -101,13 +101,22 @@ impl FileFormat for LakeSoulParquetFormat { let projection = conf.projection.clone(); let target_schema = project_schema(&table_schema, projection.as_ref())?; - let merged_projection = compute_project_column_indices(table_schema.clone(), target_schema.clone(), self.conf.primary_keys_slice()); + let merged_projection = compute_project_column_indices( + table_schema.clone(), + target_schema.clone(), + self.conf.primary_keys_slice(), + ); let merged_schema = project_schema(&table_schema, merged_projection.as_ref())?; // files to read - let flatten_conf = - flatten_file_scan_config(state, self.parquet_format.clone(), conf, self.conf.primary_keys_slice(), target_schema.clone()).await?; - + let flatten_conf = flatten_file_scan_config( + state, + self.parquet_format.clone(), + conf, + self.conf.primary_keys_slice(), + target_schema.clone(), + ) + .await?; let merge_exec = Arc::new(MergeParquetExec::new( merged_schema.clone(), @@ -116,9 +125,8 @@ impl FileFormat for LakeSoulParquetFormat { self.parquet_format.metadata_size_hint(state.config_options()), self.conf.clone(), )?); - - if target_schema.fields().len() < merged_schema.fields().len() { + if target_schema.fields().len() < merged_schema.fields().len() { let mut projection_expr = vec![]; for field in target_schema.fields() { projection_expr.push(( @@ -169,8 +177,7 @@ pub async fn flatten_file_scan_config( let statistics = format .infer_stats(state, &store, file_schema.clone(), &file.object_meta) .await?; - let projection = - compute_project_column_indices(file_schema.clone(), target_schema.clone(), primary_keys); + let projection = compute_project_column_indices(file_schema.clone(), target_schema.clone(), primary_keys); let limit = conf.limit; let table_partition_cols = conf.table_partition_cols.clone(); let output_ordering = conf.output_ordering.clone(); diff --git a/rust/lakesoul-io/src/datasource/listing.rs b/rust/lakesoul-io/src/datasource/listing.rs index 796c863e4..f7dc92a79 100644 --- a/rust/lakesoul-io/src/datasource/listing.rs +++ b/rust/lakesoul-io/src/datasource/listing.rs @@ -38,17 +38,16 @@ impl 
Debug for LakeSoulListingTable { } impl LakeSoulListingTable { - pub async fn new_with_config_and_format( session_state: &SessionState, lakesoul_io_config: LakeSoulIOConfig, file_format: Arc, as_sink: bool, ) -> Result { - - let (file_schema, listing_table) = listing_table_from_lakesoul_io_config(session_state, lakesoul_io_config.clone(), file_format, as_sink).await?; - let file_schema = file_schema - .ok_or_else(|| DataFusionError::Internal("No schema provided.".into()))?; + let (file_schema, listing_table) = + listing_table_from_lakesoul_io_config(session_state, lakesoul_io_config.clone(), file_format, as_sink) + .await?; + let file_schema = file_schema.ok_or_else(|| DataFusionError::Internal("No schema provided.".into()))?; let table_schema = Self::compute_table_schema(file_schema, lakesoul_io_config.schema()); Ok(Self { @@ -69,14 +68,13 @@ impl LakeSoulListingTable { pub fn compute_table_schema(file_schema: SchemaRef, target_schema: SchemaRef) -> SchemaRef { let target_schema = uniform_schema(target_schema); let mut builder = SchemaBuilder::from(target_schema.fields()); - for field in file_schema.fields() { - if target_schema.field_with_name(field.name()).is_err() { - builder.push(field.clone()); - } + for field in file_schema.fields() { + if target_schema.field_with_name(field.name()).is_err() { + builder.push(field.clone()); } + } Arc::new(builder.finish()) } - } #[async_trait] diff --git a/rust/lakesoul-io/src/datasource/physical_plan/defatul_column.rs b/rust/lakesoul-io/src/datasource/physical_plan/defatul_column.rs index 512ae6f73..51c6d6042 100644 --- a/rust/lakesoul-io/src/datasource/physical_plan/defatul_column.rs +++ b/rust/lakesoul-io/src/datasource/physical_plan/defatul_column.rs @@ -2,8 +2,8 @@ // // SPDX-License-Identifier: Apache-2.0 -use std::{any::Any, collections::HashMap}; use std::sync::Arc; +use std::{any::Any, collections::HashMap}; use arrow_schema::SchemaRef; use datafusion::{ @@ -19,19 +19,19 @@ use crate::default_column_stream::DefaultColumnStream; pub struct DefaultColumnExec { input: Arc, target_schema: SchemaRef, - default_column_value: Arc> + default_column_value: Arc>, } impl DefaultColumnExec { pub fn new( - input: Arc, + input: Arc, target_schema: SchemaRef, - default_column_value: Arc> + default_column_value: Arc>, ) -> Result { Ok(Self { input, target_schema, - default_column_value + default_column_value, }) } } diff --git a/rust/lakesoul-io/src/datasource/physical_plan/merge.rs b/rust/lakesoul-io/src/datasource/physical_plan/merge.rs index af771cc55..9a372b6d6 100644 --- a/rust/lakesoul-io/src/datasource/physical_plan/merge.rs +++ b/rust/lakesoul-io/src/datasource/physical_plan/merge.rs @@ -15,10 +15,12 @@ use datafusion::{ physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PhysicalExpr, SendableRecordBatchStream}, }; use datafusion_common::{DFSchemaRef, DataFusionError, Result}; +use datafusion_substrait::substrait::proto::Plan; +use log::debug; -use crate::filter::parser::Parser as FilterParser; use crate::default_column_stream::empty_schema_stream::EmptySchemaStream; use crate::default_column_stream::DefaultColumnStream; +use crate::filter::parser::Parser as FilterParser; use crate::lakesoul_io_config::LakeSoulIOConfig; use crate::sorted_merge::merge_operator::MergeOperator; use crate::sorted_merge::sorted_stream_merger::{SortedStream, SortedStreamMerger}; @@ -87,7 +89,6 @@ impl MergeParquetExec { io_config: LakeSoulIOConfig, default_column_value: Arc>, ) -> Result { - let primary_keys = Arc::new(io_config.primary_keys); let 
merge_operators = Arc::new(io_config.merge_operators); @@ -100,7 +101,6 @@ impl MergeParquetExec { }) } - pub fn primary_keys(&self) -> Arc> { self.primary_keys.clone() } @@ -251,27 +251,43 @@ fn schema_intersection(df_schema: DFSchemaRef, request_schema: SchemaRef) -> Vec exprs } +pub fn convert_filter(df: &DataFrame, filter_str: Vec, filter_protos: Vec) -> Result> { + let arrow_schema = Arc::new(Schema::from(df.schema())); + let mut str_filters = vec![]; + let mut proto_filters = vec![]; + for f in &filter_str { + let filter = FilterParser::parse(f.clone(), arrow_schema.clone())?; + str_filters.push(filter); + } + for p in &filter_protos { + let e = FilterParser::parse_proto(p)?; + proto_filters.push(e); + } + debug!("str filters: {:?}", str_filters); + debug!("proto filters: {:?}", proto_filters); + if !str_filters.is_empty() { + Ok(str_filters) + } else { + Ok(proto_filters) + } +} + pub async fn prune_filter_and_execute( df: DataFrame, request_schema: SchemaRef, - filter_str: Vec, + filters: Vec, batch_size: usize, ) -> Result { let df_schema = df.schema().clone(); - // find columns requested and prune others + // find columns requested and prune otherPlans let cols = schema_intersection(Arc::new(df_schema.clone()), request_schema.clone()); if cols.is_empty() { - Ok(Box::pin(EmptySchemaStream::new(batch_size, df.count().await?))) - } else { - // row filtering should go first since filter column may not in the selected cols - let arrow_schema = Arc::new(Schema::from(df_schema)); - let df = filter_str.iter().try_fold(df, |df, f| { - let filter = FilterParser::parse(f.clone(), arrow_schema.clone())?; - df.filter(filter) - })?; - // column pruning - let df = df.select(cols)?; - // return a stream - df.execute_stream().await + return Ok(Box::pin(EmptySchemaStream::new(batch_size, df.count().await?))); } + // row filtering should go first since filter column may not in the selected cols + let df = filters.into_iter().try_fold(df, |df, f| df.filter(f))?; + // column pruning + let df = df.select(cols)?; + // return a stream + df.execute_stream().await } diff --git a/rust/lakesoul-io/src/filter/parser.rs b/rust/lakesoul-io/src/filter/parser.rs index 34c0bb137..230ee26a0 100644 --- a/rust/lakesoul-io/src/filter/parser.rs +++ b/rust/lakesoul-io/src/filter/parser.rs @@ -2,14 +2,33 @@ // // SPDX-License-Identifier: Apache-2.0 +use std::collections::HashMap; +use std::ops::Not; +use std::str::FromStr; +use std::sync::Arc; + use anyhow::anyhow; -use arrow_schema::{DataType, Field, Fields, SchemaRef}; -use datafusion::logical_expr::Expr; +use arrow_schema::{DataType, Field, Fields, SchemaRef, TimeUnit}; +use datafusion::logical_expr::{expr, BinaryExpr, BuiltinScalarFunction, Expr, Operator}; use datafusion::prelude::col; use datafusion::scalar::ScalarValue; -use datafusion_common::{DataFusionError, Result}; -use std::ops::Not; -use std::sync::Arc; +use datafusion_common::{not_impl_err, plan_err, Column, DFSchema, DataFusionError, Result}; +use datafusion_substrait::substrait; +use datafusion_substrait::substrait::proto::expression::field_reference::ReferenceType::DirectReference; +use datafusion_substrait::substrait::proto::expression::literal::LiteralType; +use datafusion_substrait::substrait::proto::expression::reference_segment::ReferenceType::StructField; +use datafusion_substrait::substrait::proto::expression::{Literal, RexType}; +use datafusion_substrait::substrait::proto::extensions::simple_extension_declaration::MappingType; +use 
datafusion_substrait::substrait::proto::function_argument::ArgType; +use datafusion_substrait::substrait::proto::r#type::Nullability; +use datafusion_substrait::substrait::proto::read_rel::ReadType; +use datafusion_substrait::substrait::proto::rel::RelType; +use datafusion_substrait::substrait::proto::{plan_rel, r#type, Expression, Plan, Rel, Type}; +use datafusion_substrait::variation_const::{ + DATE_32_TYPE_REF, DATE_64_TYPE_REF, DECIMAL_128_TYPE_REF, DECIMAL_256_TYPE_REF, DEFAULT_CONTAINER_TYPE_REF, + DEFAULT_TYPE_REF, LARGE_CONTAINER_TYPE_REF, TIMESTAMP_MICRO_TYPE_REF, TIMESTAMP_MILLI_TYPE_REF, + TIMESTAMP_NANO_TYPE_REF, TIMESTAMP_SECOND_TYPE_REF, UNSIGNED_INTEGER_TYPE_REF, +}; pub struct Parser {} @@ -207,14 +226,405 @@ impl Parser { }; Ok(res) } + + pub(crate) fn parse_proto(plan: &Plan) -> Result { + let function_extension = plan + .extensions + .iter() + .map(|e| match &e.mapping_type { + Some(ext) => match ext { + MappingType::ExtensionFunction(ext_f) => Ok((ext_f.function_anchor, &ext_f.name)), + _ => not_impl_err!("Extension type not supported: {ext:?}"), + }, + None => not_impl_err!("Cannot parse empty extension"), + }) + .collect::>>()?; + // Parse relations + match plan.relations.len() { + 1 => match plan.relations[0].rel_type.as_ref() { + Some(rt) => match rt { + plan_rel::RelType::Rel(rel) => Ok(Parser::parse_rel(rel, &function_extension)?), + plan_rel::RelType::Root(root) => Ok(Parser::parse_rel( + root.input + .as_ref() + .ok_or(DataFusionError::Substrait("wrong root".to_string()))?, + &function_extension, + )?), + }, + None => plan_err!("Cannot parse plan relation: None"), + }, + _ => not_impl_err!( + "Substrait plan with more than 1 relation trees not supported. Number of relation trees: {:?}", + plan.relations.len() + ), + } + } + + fn parse_rel(rel: &Rel, extensions: &HashMap) -> Result { + match &rel.rel_type { + Some(RelType::Read(read)) => match &read.as_ref().read_type { + None => { + not_impl_err!("unsupported") + } + Some(ReadType::NamedTable(_nt)) => { + let named_struct = read + .base_schema + .clone() + .ok_or(DataFusionError::Substrait("wrong name table".to_string()))?; + let st = named_struct + .r#struct + .ok_or(DataFusionError::Substrait("struct get failed".to_string()))?; + let typs = st.types; + let names = named_struct.names; + let mut fields = vec![]; + for (typ, name) in typs.into_iter().zip(names.into_iter()) { + let data_type = from_substrait_type(&typ)?; + fields.push(Field::new(name, data_type.0, from_nullability(data_type.1))) + } + let sma = arrow_schema::Schema::new(fields); + let df_schema = DFSchema::try_from(sma)?; + let e = read + .filter + .as_ref() + .ok_or(DataFusionError::Substrait("wrong filter".to_string()))?; + Parser::parse_rex(e.as_ref(), &df_schema, extensions) + } + Some(_) => { + not_impl_err!("un supported") + } + }, + _ => not_impl_err!("un supported"), + } + } + + // recursion + fn parse_rex(e: &Expression, input_schema: &DFSchema, extensions: &HashMap) -> Result { + match &e.rex_type { + Some(RexType::Selection(field_ref)) => match &field_ref.reference_type { + Some(DirectReference(direct)) => match &direct.reference_type.as_ref() { + Some(StructField(x)) => match &x.child.as_ref() { + Some(_) => not_impl_err!("Direct reference StructField with child is not supported"), + None => { + let column = input_schema.field(x.field as usize).qualified_column(); + Ok(Expr::Column(Column { + relation: column.relation, + name: column.name, + })) + } + }, + _ => not_impl_err!("Direct reference with types other than StructField is not 
supported"), + }, + _ => not_impl_err!("unsupported field ref type"), + }, + Some(RexType::ScalarFunction(f)) => { + let fn_name = extensions.get(&f.function_reference).ok_or_else(|| { + DataFusionError::NotImplemented(format!( + "Aggregated function not found: function reference = {:?}", + f.function_reference + )) + })?; + let fn_type = scalar_function_type_from_str(fn_name)?; + match fn_type { + ScalarFunctionType::Builtin(fun) => { + let mut args = Vec::with_capacity(f.arguments.len()); + for arg in &f.arguments { + let arg_expr = match &arg.arg_type { + Some(ArgType::Value(e)) => Parser::parse_rex(e, input_schema, extensions), + _ => not_impl_err!("Aggregated function argument non-Value type not supported"), + }; + args.push(arg_expr?); + } + Ok(Expr::ScalarFunction(expr::ScalarFunction { fun, args })) + } + ScalarFunctionType::Op(op) => { + if f.arguments.len() != 2 { + return not_impl_err!("Expect two arguments for binary operator {op:?}"); + } + let lhs = &f.arguments[0].arg_type; + let rhs = &f.arguments[1].arg_type; + + match (lhs, rhs) { + (Some(ArgType::Value(l)), Some(ArgType::Value(r))) => Ok(Expr::BinaryExpr(BinaryExpr { + left: Box::new(Parser::parse_rex(l, input_schema, extensions)?), + op, + right: Box::new(Parser::parse_rex(r, input_schema, extensions)?), + })), + (l, r) => not_impl_err!("Invalid arguments for binary expression: {l:?} and {r:?}"), + } + } + ScalarFunctionType::Not => { + let arg = f.arguments.first().ok_or_else(|| { + DataFusionError::Substrait("expect one argument for `NOT` expr".to_string()) + })?; + match &arg.arg_type { + Some(ArgType::Value(e)) => { + let expr = Parser::parse_rex(e, input_schema, extensions)?; + Ok(Expr::Not(Box::new(expr))) + } + _ => not_impl_err!("Invalid arguments for Not expression"), + } + } + ScalarFunctionType::IsNull => { + let arg = f.arguments.first().ok_or_else(|| { + DataFusionError::Substrait("expect one argument for `IS NULL` expr".to_string()) + })?; + match &arg.arg_type { + Some(ArgType::Value(e)) => { + let expr = Parser::parse_rex(e, input_schema, extensions)?; + Ok(Expr::IsNull(Box::new(expr))) + } + _ => not_impl_err!("Invalid arguments for IS NULL expression"), + } + } + ScalarFunctionType::IsNotNull => { + let arg = f.arguments.first().ok_or_else(|| { + DataFusionError::Substrait("expect one argument for `IS NOT NULL` expr".to_string()) + })?; + match &arg.arg_type { + Some(ArgType::Value(e)) => { + let expr = Parser::parse_rex(e, input_schema, extensions)?; + Ok(Expr::IsNotNull(Box::new(expr))) + } + _ => { + not_impl_err!("Invalid arguments for IS NOT NULL expression") + } + } + } + _ => not_impl_err!("not implemented"), + } + } + Some(RexType::Literal(lit)) => { + let scalar_value = from_substrait_literal(lit)?; + Ok(Expr::Literal(scalar_value)) + } + _ => unimplemented!(), + } + } +} + +enum ScalarFunctionType { + Builtin(BuiltinScalarFunction), + Op(Operator), + /// [Expr::Not] + Not, + /// [Expr::Like] Used for filtering rows based on the given wildcard pattern. 
Case-sensitive + Like, + /// [Expr::Like] Case insensitive operator counterpart of `Like` + ILike, + /// [Expr::IsNull] + IsNull, + /// [Expr::IsNotNull] + IsNotNull, +} + +pub fn name_to_op(name: &str) -> Result { + match name { + "equal" => Ok(Operator::Eq), + "not_equal" => Ok(Operator::NotEq), + "lt" => Ok(Operator::Lt), + "lte" => Ok(Operator::LtEq), + "gt" => Ok(Operator::Gt), + "gte" => Ok(Operator::GtEq), + "add" => Ok(Operator::Plus), + "subtract" => Ok(Operator::Minus), + "multiply" => Ok(Operator::Multiply), + "divide" => Ok(Operator::Divide), + "mod" => Ok(Operator::Modulo), + "and" => Ok(Operator::And), + "or" => Ok(Operator::Or), + "is_distinct_from" => Ok(Operator::IsDistinctFrom), + "is_not_distinct_from" => Ok(Operator::IsNotDistinctFrom), + "regex_match" => Ok(Operator::RegexMatch), + "regex_imatch" => Ok(Operator::RegexIMatch), + "regex_not_match" => Ok(Operator::RegexNotMatch), + "regex_not_imatch" => Ok(Operator::RegexNotIMatch), + "bitwise_and" => Ok(Operator::BitwiseAnd), + "bitwise_or" => Ok(Operator::BitwiseOr), + "str_concat" => Ok(Operator::StringConcat), + "at_arrow" => Ok(Operator::AtArrow), + "arrow_at" => Ok(Operator::ArrowAt), + "bitwise_xor" => Ok(Operator::BitwiseXor), + "bitwise_shift_right" => Ok(Operator::BitwiseShiftRight), + "bitwise_shift_left" => Ok(Operator::BitwiseShiftLeft), + _ => not_impl_err!("Unsupported function name: {name:?}"), + } +} + +fn scalar_function_type_from_str(name: &str) -> Result { + let (name, _) = name + .split_once(':') + .ok_or(DataFusionError::Substrait("wrong func type".to_string()))?; + if let Ok(op) = datafusion_substrait::logical_plan::consumer::name_to_op(name) { + return Ok(ScalarFunctionType::Op(op)); + } + + if let Ok(fun) = BuiltinScalarFunction::from_str(name) { + return Ok(ScalarFunctionType::Builtin(fun)); + } + + match name { + "not" => Ok(ScalarFunctionType::Not), + "like" => Ok(ScalarFunctionType::Like), + "ilike" => Ok(ScalarFunctionType::ILike), + "is_null" => Ok(ScalarFunctionType::IsNull), + "is_not_null" => Ok(ScalarFunctionType::IsNotNull), + others => not_impl_err!("Unsupported function name: {others:?}"), + } +} + +fn from_substrait_literal(lit: &Literal) -> Result { + let scalar_value = match &lit.literal_type { + Some(LiteralType::Boolean(b)) => ScalarValue::Boolean(Some(*b)), + Some(LiteralType::I8(n)) => match lit.type_variation_reference { + DEFAULT_TYPE_REF => ScalarValue::Int8(Some(*n as i8)), + UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt8(Some(*n as u8)), + others => { + return Err(DataFusionError::Substrait(format!( + "Unknown type variation reference {others}", + ))); + } + }, + Some(LiteralType::I16(n)) => match lit.type_variation_reference { + DEFAULT_TYPE_REF => ScalarValue::Int16(Some(*n as i16)), + UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt16(Some(*n as u16)), + others => { + return Err(DataFusionError::Substrait(format!( + "Unknown type variation reference {others}", + ))); + } + }, + Some(LiteralType::I32(n)) => match lit.type_variation_reference { + DEFAULT_TYPE_REF => ScalarValue::Int32(Some(*n)), + UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt32(Some(*n as u32)), + others => { + return Err(DataFusionError::Substrait(format!( + "Unknown type variation reference {others}", + ))); + } + }, + Some(LiteralType::I64(n)) => match lit.type_variation_reference { + DEFAULT_TYPE_REF => ScalarValue::Int64(Some(*n)), + UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt64(Some(*n as u64)), + others => { + return Err(DataFusionError::Substrait(format!( + "Unknown type variation reference 
{others}", + ))); + } + }, + Some(LiteralType::Fp32(f)) => ScalarValue::Float32(Some(*f)), + Some(LiteralType::Fp64(f)) => ScalarValue::Float64(Some(*f)), + Some(LiteralType::Timestamp(t)) => match lit.type_variation_reference { + TIMESTAMP_SECOND_TYPE_REF => ScalarValue::TimestampSecond(Some(*t), None), + TIMESTAMP_MILLI_TYPE_REF => ScalarValue::TimestampMillisecond(Some(*t), None), + TIMESTAMP_MICRO_TYPE_REF => ScalarValue::TimestampMicrosecond(Some(*t), None), + TIMESTAMP_NANO_TYPE_REF => ScalarValue::TimestampNanosecond(Some(*t), None), + others => { + return Err(DataFusionError::Substrait(format!( + "Unknown type variation reference {others}", + ))); + } + }, + Some(LiteralType::Date(d)) => ScalarValue::Date32(Some(*d)), + Some(LiteralType::String(s)) => match lit.type_variation_reference { + DEFAULT_CONTAINER_TYPE_REF => ScalarValue::Utf8(Some(s.clone())), + LARGE_CONTAINER_TYPE_REF => ScalarValue::LargeUtf8(Some(s.clone())), + others => { + return Err(DataFusionError::Substrait(format!( + "Unknown type variation reference {others}", + ))); + } + }, + Some(LiteralType::Binary(b)) => match lit.type_variation_reference { + DEFAULT_CONTAINER_TYPE_REF => ScalarValue::Binary(Some(b.clone())), + LARGE_CONTAINER_TYPE_REF => ScalarValue::LargeBinary(Some(b.clone())), + others => { + return Err(DataFusionError::Substrait(format!( + "Unknown type variation reference {others}", + ))); + } + }, + Some(LiteralType::FixedBinary(b)) => ScalarValue::FixedSizeBinary(b.len() as _, Some(b.clone())), + Some(LiteralType::Decimal(d)) => { + let value: [u8; 16] = d.value.clone().try_into().or(Err(DataFusionError::Substrait( + "Failed to parse decimal value".to_string(), + )))?; + let p = d + .precision + .try_into() + .map_err(|e| DataFusionError::Substrait(format!("Failed to parse decimal precision: {e}")))?; + let s = d + .scale + .try_into() + .map_err(|e| DataFusionError::Substrait(format!("Failed to parse decimal scale: {e}")))?; + ScalarValue::Decimal128(Some(std::primitive::i128::from_le_bytes(value)), p, s) + } + Some(LiteralType::Null(ntype)) => from_substrait_null(ntype)?, + _ => return not_impl_err!("Unsupported literal_type: {:?}", lit.literal_type), + }; + + Ok(scalar_value) +} + +fn from_substrait_null(null_type: &Type) -> Result { + if let Some(kind) = &null_type.kind { + match kind { + r#type::Kind::Bool(_) => Ok(ScalarValue::Boolean(None)), + r#type::Kind::I8(integer) => match integer.type_variation_reference { + DEFAULT_TYPE_REF => Ok(ScalarValue::Int8(None)), + UNSIGNED_INTEGER_TYPE_REF => Ok(ScalarValue::UInt8(None)), + v => not_impl_err!("Unsupported Substrait type variation {v} of type {kind:?}"), + }, + r#type::Kind::I16(integer) => match integer.type_variation_reference { + DEFAULT_TYPE_REF => Ok(ScalarValue::Int16(None)), + UNSIGNED_INTEGER_TYPE_REF => Ok(ScalarValue::UInt16(None)), + v => not_impl_err!("Unsupported Substrait type variation {v} of type {kind:?}"), + }, + r#type::Kind::I32(integer) => match integer.type_variation_reference { + DEFAULT_TYPE_REF => Ok(ScalarValue::Int32(None)), + UNSIGNED_INTEGER_TYPE_REF => Ok(ScalarValue::UInt32(None)), + v => not_impl_err!("Unsupported Substrait type variation {v} of type {kind:?}"), + }, + r#type::Kind::I64(integer) => match integer.type_variation_reference { + DEFAULT_TYPE_REF => Ok(ScalarValue::Int64(None)), + UNSIGNED_INTEGER_TYPE_REF => Ok(ScalarValue::UInt64(None)), + v => not_impl_err!("Unsupported Substrait type variation {v} of type {kind:?}"), + }, + r#type::Kind::Fp32(_) => Ok(ScalarValue::Float32(None)), + 
r#type::Kind::Fp64(_) => Ok(ScalarValue::Float64(None)), + r#type::Kind::Timestamp(ts) => match ts.type_variation_reference { + TIMESTAMP_SECOND_TYPE_REF => Ok(ScalarValue::TimestampSecond(None, None)), + TIMESTAMP_MILLI_TYPE_REF => Ok(ScalarValue::TimestampMillisecond(None, None)), + TIMESTAMP_MICRO_TYPE_REF => Ok(ScalarValue::TimestampMicrosecond(None, None)), + TIMESTAMP_NANO_TYPE_REF => Ok(ScalarValue::TimestampNanosecond(None, None)), + v => not_impl_err!("Unsupported Substrait type variation {v} of type {kind:?}"), + }, + r#type::Kind::Date(date) => match date.type_variation_reference { + DATE_32_TYPE_REF => Ok(ScalarValue::Date32(None)), + DATE_64_TYPE_REF => Ok(ScalarValue::Date64(None)), + v => not_impl_err!("Unsupported Substrait type variation {v} of type {kind:?}"), + }, + r#type::Kind::Binary(binary) => match binary.type_variation_reference { + DEFAULT_CONTAINER_TYPE_REF => Ok(ScalarValue::Binary(None)), + LARGE_CONTAINER_TYPE_REF => Ok(ScalarValue::LargeBinary(None)), + v => not_impl_err!("Unsupported Substrait type variation {v} of type {kind:?}"), + }, + // FixedBinary is not supported because `None` doesn't have length + r#type::Kind::String(string) => match string.type_variation_reference { + DEFAULT_CONTAINER_TYPE_REF => Ok(ScalarValue::Utf8(None)), + LARGE_CONTAINER_TYPE_REF => Ok(ScalarValue::LargeUtf8(None)), + v => not_impl_err!("Unsupported Substrait type variation {v} of type {kind:?}"), + }, + r#type::Kind::Decimal(d) => Ok(ScalarValue::Decimal128(None, d.precision as u8, d.scale as i8)), + _ => not_impl_err!("Unsupported Substrait type: {kind:?}"), + } + } else { + not_impl_err!("Null type without kind is not supported") + } } fn qualified_expr(expr_str: &str, schema: SchemaRef) -> Option<(Expr, Arc)> { if let Ok(field) = schema.field_with_name(expr_str) { - Some(( - col(datafusion::common::Column::new_unqualified(expr_str)), - Arc::new(field.clone()), - )) + Some((col(Column::new_unqualified(expr_str)), Arc::new(field.clone()))) } else { let mut expr: Option<(Expr, Arc)> = None; let mut root = "".to_owned(); @@ -229,10 +639,7 @@ fn qualified_expr(expr_str: &str, schema: SchemaRef) -> Option<(Expr, Arc expr = if let Some((folding_exp, _)) = expr { Some((folding_exp.field(field.name()), field.clone())) } else { - Some(( - col(datafusion::common::Column::new_unqualified(field.name())), - field.clone(), - )) + Some((col(Column::new_unqualified(field.name())), field.clone())) }; root = "".to_owned(); @@ -246,11 +653,99 @@ fn qualified_expr(expr_str: &str, schema: SchemaRef) -> Option<(Expr, Arc } } +fn from_substrait_type(dt: &substrait::proto::Type) -> Result<(DataType, Nullability)> { + match &dt.kind { + Some(s_kind) => match s_kind { + r#type::Kind::Bool(b) => Ok((DataType::Boolean, b.nullability())), + r#type::Kind::I8(integer) => match integer.type_variation_reference { + DEFAULT_TYPE_REF => Ok((DataType::Int8, integer.nullability())), + UNSIGNED_INTEGER_TYPE_REF => Ok((DataType::UInt8, integer.nullability())), + v => not_impl_err!("Unsupported Substrait type variation {v} of type {s_kind:?}"), + }, + r#type::Kind::I16(integer) => match integer.type_variation_reference { + DEFAULT_TYPE_REF => Ok((DataType::Int16, integer.nullability())), + UNSIGNED_INTEGER_TYPE_REF => Ok((DataType::UInt16, integer.nullability())), + v => not_impl_err!("Unsupported Substrait type variation {v} of type {s_kind:?}"), + }, + r#type::Kind::I32(integer) => match integer.type_variation_reference { + DEFAULT_TYPE_REF => Ok((DataType::Int32, integer.nullability())), + 
UNSIGNED_INTEGER_TYPE_REF => Ok((DataType::UInt32, integer.nullability())), + v => not_impl_err!("Unsupported Substrait type variation {v} of type {s_kind:?}"), + }, + r#type::Kind::I64(integer) => match integer.type_variation_reference { + DEFAULT_TYPE_REF => Ok((DataType::Int64, integer.nullability())), + UNSIGNED_INTEGER_TYPE_REF => Ok((DataType::UInt64, integer.nullability())), + v => not_impl_err!("Unsupported Substrait type variation {v} of type {s_kind:?}"), + }, + r#type::Kind::Fp32(fp) => Ok((DataType::Float32, fp.nullability())), + r#type::Kind::Fp64(fp) => Ok((DataType::Float64, fp.nullability())), + r#type::Kind::Timestamp(ts) => match ts.type_variation_reference { + TIMESTAMP_SECOND_TYPE_REF => Ok((DataType::Timestamp(TimeUnit::Second, None), ts.nullability())), + TIMESTAMP_MILLI_TYPE_REF => Ok((DataType::Timestamp(TimeUnit::Millisecond, None), ts.nullability())), + TIMESTAMP_MICRO_TYPE_REF => Ok((DataType::Timestamp(TimeUnit::Microsecond, None), ts.nullability())), + TIMESTAMP_NANO_TYPE_REF => Ok((DataType::Timestamp(TimeUnit::Nanosecond, None), ts.nullability())), + v => not_impl_err!("Unsupported Substrait type variation {v} of type {s_kind:?}"), + }, + r#type::Kind::Date(date) => match date.type_variation_reference { + DATE_32_TYPE_REF => Ok((DataType::Date32, date.nullability())), + DATE_64_TYPE_REF => Ok((DataType::Date64, date.nullability())), + v => not_impl_err!("Unsupported Substrait type variation {v} of type {s_kind:?}"), + }, + r#type::Kind::Binary(binary) => match binary.type_variation_reference { + DEFAULT_CONTAINER_TYPE_REF => Ok((DataType::Binary, binary.nullability())), + LARGE_CONTAINER_TYPE_REF => Ok((DataType::LargeBinary, binary.nullability())), + v => not_impl_err!("Unsupported Substrait type variation {v} of type {s_kind:?}"), + }, + r#type::Kind::FixedBinary(fixed) => Ok((DataType::FixedSizeBinary(fixed.length), fixed.nullability())), + r#type::Kind::String(string) => match string.type_variation_reference { + DEFAULT_CONTAINER_TYPE_REF => Ok((DataType::Utf8, string.nullability())), + LARGE_CONTAINER_TYPE_REF => Ok((DataType::LargeUtf8, string.nullability())), + v => not_impl_err!("Unsupported Substrait type variation {v} of type {s_kind:?}"), + }, + r#type::Kind::List(_list) => { + not_impl_err!("Unsupported") + // let (inner_type, _nullablity) = + // from_substrait_type(list.r#type.as_ref().ok_or_else(|| { + // DataFusionError::Substrait( + // "List type must have inner type".to_string(), + // ) + // })?)?; + // let field = Arc::new(Field::new("list_item", inner_type, true)); + // match list.type_variation_reference { + // DEFAULT_CONTAINER_TYPE_REF => Ok(DataType::List(field)), + // LARGE_CONTAINER_TYPE_REF => Ok(DataType::LargeList(field)), + // v => not_impl_err!( + // "Unsupported Substrait type variation {v} of type {s_kind:?}" + // )?, + // } + } + r#type::Kind::Decimal(d) => match d.type_variation_reference { + DECIMAL_128_TYPE_REF => Ok((DataType::Decimal128(d.precision as u8, d.scale as i8), d.nullability())), + DECIMAL_256_TYPE_REF => Ok((DataType::Decimal256(d.precision as u8, d.scale as i8), d.nullability())), + v => not_impl_err!("Unsupported Substrait type variation {v} of type {s_kind:?}"), + }, + _ => not_impl_err!("Unsupported Substrait type: {s_kind:?}"), + }, + _ => not_impl_err!("`None` Substrait kind is not supported"), + } +} + +fn from_nullability(nullability: Nullability) -> bool { + match nullability { + Nullability::Unspecified => true, + Nullability::Nullable => true, + Nullability::Required => false, + } +} + #[cfg(test)] 
mod tests { - use crate::filter::parser::Parser; use std::result::Result; + use prost::Message; + + use super::*; + #[test] fn test_filter_parser() -> Result<(), String> { let s = String::from("or(lt(a.b.c, 2.0), gt(a.b.c, 3.0))"); @@ -260,4 +755,25 @@ mod tests { assert_eq!(right, "gt(a.b.c, 3.0)"); Ok(()) } + + #[test] + fn t() { + let v = [ + 10, 30, 8, 1, 18, 26, 47, 102, 117, 110, 99, 116, 105, 111, 110, 115, 95, 99, 111, 109, 112, 97, 114, 105, + 115, 111, 110, 46, 121, 97, 109, 108, 18, 19, 26, 17, 8, 1, 26, 13, 101, 113, 117, 97, 108, 58, 97, 110, + 121, 95, 97, 110, 121, 26, 83, 18, 81, 10, 79, 10, 77, 10, 2, 10, 0, 18, 28, 10, 4, 99, 111, 108, 49, 10, + 4, 99, 111, 108, 50, 18, 14, 10, 4, 42, 2, 16, 2, 10, 4, 42, 2, 16, 2, 24, 2, 26, 30, 26, 28, 26, 4, 10, 2, + 16, 2, 34, 12, 26, 10, 18, 8, 10, 4, 18, 2, 8, 1, 34, 0, 34, 6, 26, 4, 10, 2, 40, 3, 58, 9, 10, 7, 97, 95, + 116, 97, 98, 108, 101, + ]; + let plan = Plan::decode(&v[..]).unwrap(); + println!("{:?}", plan); + // let f1 = Field::new("col1", DataType::Int32, true); + // let f2 = Field::new("col2", DataType::Int32, true); + // let sma = arrow_schema::Schema::new(vec![f1, f2]); + // let sma = DFSchema::try_from(sma).unwrap(); + // println!("{:?}", plan); + // let expr = Parser::parse_proto(&plan, &sma).unwrap(); + // println!("{:#?}", expr); + } } diff --git a/rust/lakesoul-io/src/helpers.rs b/rust/lakesoul-io/src/helpers.rs index 77ef12794..b31b3631e 100644 --- a/rust/lakesoul-io/src/helpers.rs +++ b/rust/lakesoul-io/src/helpers.rs @@ -6,7 +6,16 @@ use std::{collections::HashMap, sync::Arc}; use arrow_schema::{DataType, Schema, SchemaBuilder, SchemaRef}; use datafusion::{ - datasource::{file_format::FileFormat, listing::{ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl}, physical_plan::FileScanConfig}, execution::context::SessionState, logical_expr::col, physical_expr::{create_physical_expr, PhysicalSortExpr}, physical_plan::PhysicalExpr, physical_planner::create_physical_sort_expr + datasource::{ + file_format::FileFormat, + listing::{ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl}, + physical_plan::FileScanConfig, + }, + execution::context::SessionState, + logical_expr::col, + physical_expr::{create_physical_expr, PhysicalSortExpr}, + physical_plan::PhysicalExpr, + physical_planner::create_physical_sort_expr, }; use datafusion_common::{DFSchema, DataFusionError, Result}; use object_store::path::Path; @@ -63,36 +72,29 @@ fn range_partition_to_partition_cols( .collect::>>() } -pub fn partition_desc_from_file_scan_config( - conf: &FileScanConfig -) -> Result<(String, HashMap)> { +pub fn partition_desc_from_file_scan_config(conf: &FileScanConfig) -> Result<(String, HashMap)> { if conf.table_partition_cols.is_empty() { Ok(("-5".to_string(), HashMap::default())) } else { match conf.file_groups.first().unwrap().first() { - Some(file) => Ok( - (conf - .table_partition_cols + Some(file) => Ok(( + conf.table_partition_cols .iter() .enumerate() - .map(|(idx, col)| { - format!("{}={}", col.name().clone(), file.partition_values[idx].to_string()) - }) + .map(|(idx, col)| format!("{}={}", col.name().clone(), file.partition_values[idx])) .collect::>() .join(","), HashMap::from_iter( - conf - .table_partition_cols + conf.table_partition_cols .iter() .enumerate() - .map(|(idx, col)| { - (col.name().clone(), file.partition_values[idx].to_string()) - }) - )) + .map(|(idx, col)| (col.name().clone(), file.partition_values[idx].to_string())), ), - None => Err(DataFusionError::External(format!("Invalid 
file_group {:?}", conf.file_groups).into())), + )), + None => Err(DataFusionError::External( + format!("Invalid file_group {:?}", conf.file_groups).into(), + )), } - } } @@ -100,7 +102,7 @@ pub async fn listing_table_from_lakesoul_io_config( session_state: &SessionState, lakesoul_io_config: LakeSoulIOConfig, file_format: Arc, - as_sink: bool + as_sink: bool, ) -> Result<(Option, Arc)> { let config = match as_sink { false => { @@ -118,7 +120,8 @@ pub async fn listing_table_from_lakesoul_io_config( let store = session_state.runtime_env().object_store(object_store_url.clone())?; let target_schema = uniform_schema(lakesoul_io_config.schema()); - let table_partition_cols = range_partition_to_partition_cols(target_schema.clone(), lakesoul_io_config.range_partitions_slice())?; + let table_partition_cols = + range_partition_to_partition_cols(target_schema.clone(), lakesoul_io_config.range_partitions_slice())?; let listing_options = ListingOptions::new(file_format.clone()) .with_file_extension(".parquet") .with_table_partition_cols(table_partition_cols); @@ -126,7 +129,13 @@ pub async fn listing_table_from_lakesoul_io_config( let mut objects = vec![]; for url in &table_paths { - objects.push(store.head(&Path::from_url_path(>::as_ref(url).path())?).await?); + objects.push( + store + .head(&Path::from_url_path( + >::as_ref(url).path(), + )?) + .await?, + ); } // Resolve the schema let resolved_schema = file_format.infer_schema(session_state, &store, &objects).await?; @@ -145,14 +154,14 @@ pub async fn listing_table_from_lakesoul_io_config( } true => { let target_schema = uniform_schema(lakesoul_io_config.schema()); - let table_partition_cols = range_partition_to_partition_cols(target_schema.clone(), lakesoul_io_config.range_partitions_slice())?; + let table_partition_cols = + range_partition_to_partition_cols(target_schema.clone(), lakesoul_io_config.range_partitions_slice())?; let listing_options = ListingOptions::new(file_format.clone()) .with_file_extension(".parquet") .with_table_partition_cols(table_partition_cols) .with_insert_mode(datafusion::datasource::listing::ListingTableInsertMode::AppendNewFiles); - let prefix = - ListingTableUrl::parse_create_local_if_not_exists(lakesoul_io_config.prefix.clone(), true)?; + let prefix = ListingTableUrl::parse_create_local_if_not_exists(lakesoul_io_config.prefix.clone(), true)?; ListingTableConfig::new(prefix) .with_listing_options(listing_options) @@ -162,4 +171,3 @@ pub async fn listing_table_from_lakesoul_io_config( Ok((config.file_schema.clone(), Arc::new(ListingTable::try_new(config)?))) } - diff --git a/rust/lakesoul-io/src/lakesoul_io_config.rs b/rust/lakesoul-io/src/lakesoul_io_config.rs index 5efe1e19e..05c6428f6 100644 --- a/rust/lakesoul-io/src/lakesoul_io_config.rs +++ b/rust/lakesoul-io/src/lakesoul_io_config.rs @@ -2,6 +2,10 @@ // // SPDX-License-Identifier: Apache-2.0 +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; + use anyhow::anyhow; use arrow::error::ArrowError; use arrow_schema::{Schema, SchemaRef}; @@ -14,12 +18,10 @@ use datafusion::optimizer::push_down_filter::PushDownFilter; use datafusion::optimizer::push_down_projection::PushDownProjection; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::DataFusionError::{External, ObjectStore}; +use datafusion_substrait::substrait::proto::Plan; use derivative::Derivative; use object_store::aws::AmazonS3Builder; use object_store::{ClientOptions, RetryConfig}; -use std::collections::HashMap; -use std::sync::Arc; -use 
std::time::Duration; use url::{ParseError, Url}; #[cfg(feature = "hdfs")] @@ -56,6 +58,7 @@ pub struct LakeSoulIOConfig { // filtering predicates pub(crate) filter_strs: Vec, pub(crate) filters: Vec, + pub(crate) filter_protos: Vec, // read or write batch size #[derivative(Default(value = "8192"))] pub(crate) batch_size: usize, @@ -152,7 +155,6 @@ impl LakeSoulIOConfigBuilder { self } - pub fn with_range_partitions(mut self, range_partitions: Vec) -> Self { self.config.range_partitions = range_partitions; self @@ -208,6 +210,11 @@ impl LakeSoulIOConfigBuilder { self } + pub fn with_filter_proto(mut self, filter_proto: Plan) -> Self { + self.config.filter_protos.push(filter_proto); + self + } + pub fn with_filters(mut self, filters: Vec) -> Self { self.config.filters = filters; self diff --git a/rust/lakesoul-io/src/lakesoul_reader.rs b/rust/lakesoul-io/src/lakesoul_reader.rs index 6546a121d..c2dd3d2aa 100644 --- a/rust/lakesoul-io/src/lakesoul_reader.rs +++ b/rust/lakesoul-io/src/lakesoul_reader.rs @@ -2,28 +2,25 @@ // // SPDX-License-Identifier: Apache-2.0 -use atomic_refcell::AtomicRefCell; -use datafusion::datasource::file_format::parquet::ParquetFormat; -use datafusion::physical_plan::SendableRecordBatchStream; use std::sync::Arc; use arrow_schema::SchemaRef; - +use atomic_refcell::AtomicRefCell; pub use datafusion::arrow::error::ArrowError; pub use datafusion::arrow::error::Result as ArrowResult; pub use datafusion::arrow::record_batch::RecordBatch; +use datafusion::datasource::file_format::parquet::ParquetFormat; pub use datafusion::error::{DataFusionError, Result}; - +use datafusion::physical_plan::SendableRecordBatchStream; use datafusion::prelude::SessionContext; - use futures::StreamExt; - use tokio::runtime::Runtime; use tokio::sync::Mutex; use tokio::task::JoinHandle; use crate::datasource::file_format::LakeSoulParquetFormat; use crate::datasource::listing::LakeSoulListingTable; +use crate::datasource::physical_plan::merge::convert_filter; use crate::datasource::physical_plan::merge::prune_filter_and_execute; use crate::lakesoul_io_config::{create_session_context, LakeSoulIOConfig}; @@ -65,13 +62,12 @@ impl LakeSoulReader { .await?; let dataframe = self.sess_ctx.read_table(Arc::new(source))?; - let stream = prune_filter_and_execute( - dataframe, - schema.clone(), + let filters = convert_filter( + &dataframe, self.config.filter_strs.clone(), - self.config.batch_size, - ) - .await?; + self.config.filter_protos.clone(), + )?; + let stream = prune_filter_and_execute(dataframe, schema.clone(), filters, self.config.batch_size).await?; self.schema = Some(stream.schema()); self.stream = Some(stream); @@ -153,16 +149,23 @@ impl SyncSendableMutableLakeSoulReader { #[cfg(test)] mod tests { - use super::*; - use rand::prelude::*; use std::mem::ManuallyDrop; use std::ops::Not; use std::sync::mpsc::sync_channel; + use std::thread; use std::time::Instant; - use tokio::runtime::Builder; use arrow::datatypes::{DataType, Field, Schema}; use arrow::util::pretty::print_batches; + use datafusion::logical_expr::{col, Expr}; + use datafusion_common::ScalarValue; + use rand::prelude::*; + use tokio::runtime::Builder; + use tokio::time::{sleep, Duration}; + + use crate::lakesoul_io_config::LakeSoulIOConfigBuilder; + + use super::*; #[tokio::test] async fn test_reader_local() -> Result<()> { @@ -170,7 +173,7 @@ mod tests { let reader_conf = LakeSoulIOConfigBuilder::new() .with_files(vec![ 
project_dir.join("../lakesoul-io-java/src/test/resources/sample-parquet-files/part-00000-a9e77425-5fb4-456f-ba52-f821123bd193-c000.snappy.parquet").into_os_string().into_string().unwrap() - ]) + ]) .with_thread_num(1) .with_batch_size(256) .build(); @@ -191,7 +194,7 @@ mod tests { let project_dir = std::env::current_dir()?; let reader_conf = LakeSoulIOConfigBuilder::new() .with_files(vec![ - project_dir.join("../lakesoul-io-java/src/test/resources/sample-parquet-files/part-00000-a9e77425-5fb4-456f-ba52-f821123bd193-c000.snappy.parquet").into_os_string().into_string().unwrap() + project_dir.join("../lakesoul-io-java/src/test/resources/sample-parquet-files/part-00000-a9e77425-5fb4-456f-ba52-f821123bd193-c000.snappy.parquet").into_os_string().into_string().unwrap() ]) .with_thread_num(2) .with_batch_size(11) @@ -274,8 +277,6 @@ mod tests { Ok(()) } - use tokio::time::{sleep, Duration}; - #[tokio::test] async fn test_reader_s3() -> Result<()> { let reader_conf = LakeSoulIOConfigBuilder::new() @@ -307,7 +308,6 @@ mod tests { Ok(()) } - use std::thread; #[test] fn test_reader_s3_blocked() -> Result<()> { let reader_conf = LakeSoulIOConfigBuilder::new() @@ -357,10 +357,6 @@ mod tests { Ok(()) } - use crate::lakesoul_io_config::LakeSoulIOConfigBuilder; - use datafusion::logical_expr::{col, Expr}; - use datafusion_common::ScalarValue; - async fn get_num_rows_of_file_with_filters(file_path: String, filters: Vec) -> Result { let reader_conf = LakeSoulIOConfigBuilder::new() .with_files(vec![file_path]) diff --git a/rust/lakesoul-io/src/repartition/mod.rs b/rust/lakesoul-io/src/repartition/mod.rs index 5ad0864dc..d7acc7e52 100644 --- a/rust/lakesoul-io/src/repartition/mod.rs +++ b/rust/lakesoul-io/src/repartition/mod.rs @@ -3,11 +3,14 @@ // SPDX-License-Identifier: Apache-2.0 use std::{ - any::Any, collections::HashMap, pin::Pin, sync::Arc, task::{Context, Poll} + any::Any, + collections::HashMap, + pin::Pin, + sync::Arc, + task::{Context, Poll}, }; use arrow_schema::SchemaRef; -use datafusion::{physical_expr::physical_exprs_equal, physical_plan::metrics}; use datafusion::{ execution::{ memory_pool::{MemoryConsumer, MemoryReservation}, @@ -21,6 +24,7 @@ use datafusion::{ SendableRecordBatchStream, }, }; +use datafusion::{physical_expr::physical_exprs_equal, physical_plan::metrics}; use datafusion_common::{DataFusionError, Result}; use arrow_array::{builder::UInt64Builder, ArrayRef, RecordBatch}; @@ -80,9 +84,9 @@ impl BatchPartitioner { /// /// The time spent repartitioning will be recorded to `timer` pub fn try_new( - range_partitioning_expr:Vec>, - hash_partitioning: Partitioning, - timer: metrics::Time + range_partitioning_expr: Vec>, + hash_partitioning: Partitioning, + timer: metrics::Time, ) -> Result { let state = match hash_partitioning { Partitioning::Hash(exprs, num_partitions) => BatchPartitionerState { @@ -139,7 +143,8 @@ impl BatchPartitioner { let it: Box> + Send> = { let timer = self.timer.timer(); - let range_arrays = [range_exprs.clone()].concat() + let range_arrays = [range_exprs.clone()] + .concat() .iter() .map(|expr| expr.evaluate(&batch)?.into_array(batch.num_rows())) .collect::>>()?; @@ -152,8 +157,7 @@ impl BatchPartitioner { hash_buffer.clear(); hash_buffer.resize(batch.num_rows(), 0); - let mut range_buffer = Vec::::new(); - range_buffer.resize(batch.num_rows(), 0); + let mut range_buffer = vec![0; batch.num_rows()]; create_hashes(&hash_arrays, hash_buffer)?; create_hashes(&range_arrays, &mut range_buffer)?; @@ -164,12 +168,12 @@ impl BatchPartitioner { .collect(); for (index, 
(hash, range_hash)) in hash_buffer.iter().zip(range_buffer).enumerate() { - if !indices[(*hash % *partitions as u32) as usize].contains_key(&range_hash) { - indices[(*hash % *partitions as u32) as usize].insert(range_hash, UInt64Builder::with_capacity(batch.num_rows())); - } - if let Some(entry) = indices[(*hash % *partitions as u32) as usize].get_mut(&range_hash) { + indices[(*hash % *partitions as u32) as usize] + .entry(range_hash) + .or_insert_with(|| UInt64Builder::with_capacity(batch.num_rows())); + if let Some(entry) = indices[(*hash % *partitions as u32) as usize].get_mut(&range_hash) { entry.append_value(index as u64); - } + } } let it = indices @@ -311,16 +315,29 @@ impl RepartitionByRangeAndHashExec { /// Create a new RepartitionExec, that produces output `partitioning`, and /// does not preserve the order of the input (see [`Self::with_preserve_order`] /// for more details) - pub fn try_new(input: Arc, range_partitioning_expr:Vec>, hash_partitioning: Partitioning) -> Result { + pub fn try_new( + input: Arc, + range_partitioning_expr: Vec>, + hash_partitioning: Partitioning, + ) -> Result { if let Some(ordering) = input.output_ordering() { - let lhs = ordering.iter().map(|sort_expr| sort_expr.expr.clone()).collect::>(); - let rhs = [ + let lhs = ordering + .iter() + .map(|sort_expr| sort_expr.expr.clone()) + .collect::>(); + let rhs = [ range_partitioning_expr.clone(), match &hash_partitioning { Partitioning::Hash(hash_exprs, _) => hash_exprs.clone(), - _ => return Err(DataFusionError::Plan(format!("Invalid hash_partitioning={} for RepartitionByRangeAndHashExec", hash_partitioning))), + _ => { + return Err(DataFusionError::Plan(format!( + "Invalid hash_partitioning={} for RepartitionByRangeAndHashExec", + hash_partitioning + ))) + } }, - ].concat(); + ] + .concat(); if physical_exprs_equal(&lhs, &rhs) { return Ok(Self { @@ -332,18 +349,17 @@ impl RepartitionByRangeAndHashExec { abort_helper: Arc::new(AbortOnDropMany::<()>(vec![])), })), metrics: ExecutionPlanMetricsSet::new(), - }) + }); } - } + } Err(DataFusionError::Plan( format!( "Input ordering {:?} mismatch for RepartitionByRangeAndHashExec with range_partitioning_expr={:?}, hash_partitioning={}", - input.output_ordering(), + input.output_ordering(), range_partitioning_expr, hash_partitioning, )) ) - } /// Return the sort expressions that are used to merge @@ -364,11 +380,8 @@ impl RepartitionByRangeAndHashExec { metrics: RepartitionMetrics, context: Arc, ) -> Result<()> { - let mut partitioner = BatchPartitioner::try_new( - range_partitioning, - hash_partitioning, - metrics.repartition_time.clone() - )?; + let mut partitioner = + BatchPartitioner::try_new(range_partitioning, hash_partitioning, metrics.repartition_time.clone())?; // execute the child operator let timer = metrics.fetch_time.timer(); @@ -519,12 +532,11 @@ impl ExecutionPlan for RepartitionByRangeAndHashExec { } fn with_new_children(self: Arc, mut children: Vec>) -> Result> { - let repartition = - RepartitionByRangeAndHashExec::try_new( - children.swap_remove(0), - self.range_partitioning_expr.clone(), - self.hash_partitioning.clone() - )?; + let repartition = RepartitionByRangeAndHashExec::try_new( + children.swap_remove(0), + self.range_partitioning_expr.clone(), + self.hash_partitioning.clone(), + )?; Ok(Arc::new(repartition)) } @@ -546,17 +558,17 @@ impl ExecutionPlan for RepartitionByRangeAndHashExec { // let rxs = transpose(rxs); // (txs, rxs) // } else { - // create one channel per *output* partition - // note we use a custom channel that ensures there 
is always data for each receiver - // but limits the amount of buffering if required. - let (txs, rxs) = channels(num_output_partitions); - // Clone sender for each input partitions - let txs = txs - .into_iter() - .map(|item| vec![item; num_input_partitions]) - .collect::>(); - let rxs = rxs.into_iter().map(|item| vec![item]).collect::>(); - (txs, rxs) + // create one channel per *output* partition + // note we use a custom channel that ensures there is always data for each receiver + // but limits the amount of buffering if required. + let (txs, rxs) = channels(num_output_partitions); + // Clone sender for each input partitions + let txs = txs + .into_iter() + .map(|item| vec![item; num_input_partitions]) + .collect::>(); + let rxs = rxs.into_iter().map(|item| vec![item]).collect::>(); + (txs, rxs) }; for (partition, (tx, rx)) in txs.into_iter().zip(rxs).enumerate() { let reservation = Arc::new(Mutex::new( @@ -646,7 +658,7 @@ impl ExecutionPlan for RepartitionByRangeAndHashExec { // fetch, // merge_reservation, // ) - // } else { + // } else { Ok(Box::pin(RepartitionStream { num_input_partitions, num_input_partitions_processed: 0, @@ -655,9 +667,7 @@ impl ExecutionPlan for RepartitionByRangeAndHashExec { drop_helper: Arc::clone(&state.abort_helper), reservation, })) - // } - - + // } } } @@ -685,17 +695,12 @@ struct RepartitionStream { impl Stream for RepartitionStream { type Item = Result; - fn poll_next( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { match self.input.recv().poll_unpin(cx) { Poll::Ready(Some(Some(v))) => { if let Ok(batch) = &v { - self.reservation - .lock() - .shrink(batch.get_array_memory_size()); + self.reservation.lock().shrink(batch.get_array_memory_size()); } return Poll::Ready(Some(v)); @@ -729,7 +734,6 @@ impl RecordBatchStream for RepartitionStream { } } - /// This struct converts a receiver to a stream. /// Receiver receives data on an SPSC channel. 
struct PerPartitionStream { diff --git a/rust/lakesoul-metadata-c/src/lib.rs b/rust/lakesoul-metadata-c/src/lib.rs index 872edcbd0..7e9249a02 100644 --- a/rust/lakesoul-metadata-c/src/lib.rs +++ b/rust/lakesoul-metadata-c/src/lib.rs @@ -10,15 +10,15 @@ use core::ffi::c_ptrdiff_t; use std::collections::HashMap; use std::ffi::{c_char, c_uchar, CStr, CString}; use std::io::Write; -use std::ptr::{NonNull, null, null_mut}; +use std::ptr::{null, null_mut, NonNull}; use log::debug; use prost::bytes::BufMut; use prost::Message; -use lakesoul_metadata::{Builder, Client, MetaDataClient, PreparedStatementMap, Runtime}; use lakesoul_metadata::error::LakeSoulMetaDataError; use lakesoul_metadata::transfusion::SplitDesc; +use lakesoul_metadata::{Builder, Client, MetaDataClient, PreparedStatementMap, Runtime}; use proto::proto::entity; #[repr(C)] @@ -391,8 +391,11 @@ pub extern "C" fn create_split_desc_array( let (ret, status, e) = match result { Ok(ptr) => (ptr, true, null()), - Err(e) => - (null_mut(), false, CString::new(e.to_string()).unwrap().into_raw() as *const c_char), + Err(e) => ( + null_mut(), + false, + CString::new(e.to_string()).unwrap().into_raw() as *const c_char, + ), }; call_result_callback(callback, status, e); ret @@ -414,7 +417,9 @@ pub extern "C" fn debug(callback: extern "C" fn(bool, *const c_char)) -> *mut c_ primary_keys: vec![], partition_desc: HashMap::new(), table_schema: "".to_string(), - }; 1]; + }; + 1 + ]; let array = lakesoul_metadata::transfusion::SplitDescArray(x); let json_vec = serde_json::to_vec(&array).unwrap(); let c_string = CString::new(json_vec).unwrap(); diff --git a/rust/lakesoul-metadata/src/metadata_client.rs b/rust/lakesoul-metadata/src/metadata_client.rs index c784359cd..b68ad1e75 100644 --- a/rust/lakesoul-metadata/src/metadata_client.rs +++ b/rust/lakesoul-metadata/src/metadata_client.rs @@ -504,12 +504,7 @@ impl MetaDataClient { } } - - pub async fn get_data_files_by_table_name( - &self, - table_name: &str, - namespace: &str, - ) -> Result> { + pub async fn get_data_files_by_table_name(&self, table_name: &str, namespace: &str) -> Result> { let table_info = self.get_table_info_by_table_name(table_name, namespace).await?; debug!("table_info: {:?}", table_info); let partition_list = self.get_all_partition_info(table_info.table_id.as_str()).await?; @@ -521,24 +516,16 @@ impl MetaDataClient { self.get_data_files_of_partitions(partition_list).await } - pub async fn get_data_files_of_partitions( - &self, - partition_list: Vec, - ) -> Result> { + pub async fn get_data_files_of_partitions(&self, partition_list: Vec) -> Result> { let mut data_files = Vec::::new(); for partition_info in &partition_list { let _data_file_list = self.get_data_files_of_single_partition(partition_info).await?; data_files.extend_from_slice(&_data_file_list); - } Ok(data_files) - } - pub async fn get_data_files_of_single_partition( - &self, - partition_info: &PartitionInfo, - ) -> Result> { + pub async fn get_data_files_of_single_partition(&self, partition_info: &PartitionInfo) -> Result> { let data_commit_info_list = self.get_data_commit_info_of_single_partition(partition_info).await?; // let data_commit_info_list = Vec::::new(); let data_file_list = data_commit_info_list @@ -552,10 +539,8 @@ impl MetaDataClient { }) .collect::>(); Ok(data_file_list) - } - async fn get_data_commit_info_of_single_partition( &self, partition_info: &PartitionInfo, diff --git a/rust/lakesoul-metadata/src/transfusion.rs b/rust/lakesoul-metadata/src/transfusion.rs index 70bae9253..f32172cd6 100644 --- 
a/rust/lakesoul-metadata/src/transfusion.rs +++ b/rust/lakesoul-metadata/src/transfusion.rs @@ -17,12 +17,12 @@ use tokio_postgres::Client; use proto::proto::entity::{DataCommitInfo, DataFileOp, FileOp, JniWrapper, PartitionInfo, TableInfo}; -use crate::{DaoType, error::Result, execute_query, PARAM_DELIM, PreparedStatementMap}; use crate::error::LakeSoulMetaDataError; use crate::transfusion::config::{ LAKESOUL_HASH_PARTITION_SPLITTER, LAKESOUL_NON_PARTITION_TABLE_PART_DESC, LAKESOUL_PARTITION_SPLITTER_OF_RANGE_AND_HASH, LAKESOUL_RANGE_PARTITION_SPLITTER, }; +use crate::{error::Result, execute_query, DaoType, PreparedStatementMap, PARAM_DELIM}; mod config { #![allow(unused)] @@ -117,9 +117,7 @@ pub async fn split_desc_array( None => { return Err(LakeSoulMetaDataError::Internal("split error".to_string())); } - Some((k, v)) => { - (k.to_string(), v.to_string()) - } + Some((k, v)) => (k.to_string(), v.to_string()), }; range_desc.insert(k, v); } @@ -137,7 +135,6 @@ pub async fn split_desc_array( Ok(SplitDescArray(splits)) } - struct RawClient<'a> { client: Mutex<&'a Client>, prepared: Mutex<&'a mut PreparedStatementMap>, @@ -173,8 +170,7 @@ impl<'a> RawClient<'_> { pub async fn get_table_data_info(&self, table_id: &str) -> Result> { // logic from scala: DataOperation let vec = self.get_all_partition_info(table_id).await?; - self.get_table_data_info_by_partition_info(vec) - .await + self.get_table_data_info_by_partition_info(vec).await } async fn get_table_data_info_by_partition_info( @@ -188,7 +184,6 @@ impl<'a> RawClient<'_> { Ok(file_info_buf) } - /// return file info in this partition that match the current read version async fn get_single_partition_data_info(&self, partition_info: &PartitionInfo) -> Result> { let mut file_arr_buf = Vec::new(); @@ -240,7 +235,7 @@ impl<'a> RawClient<'_> { query_type, joined_string.clone(), ) - .await?; + .await?; Ok(JniWrapper::decode(prost::bytes::Bytes::from(encoded))?) }
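
A minimal usage sketch of the Substrait filter hook added above, assuming it is called from inside the lakesoul-io crate: decode a protobuf-encoded Plan (as the `t()` test does) and attach it to the reader config through the new `with_filter_proto` builder method; on the read path the config's `filter_protos` are then combined with any string filters by `convert_filter` before `prune_filter_and_execute`. The file path, plan bytes, and helper name below are illustrative placeholders, not part of the change.

use prost::Message;
use datafusion_substrait::substrait::proto::Plan;

use crate::lakesoul_io_config::{LakeSoulIOConfig, LakeSoulIOConfigBuilder};

/// Hypothetical helper: build a reader config that carries a Substrait filter plan.
fn config_with_substrait_filter(plan_bytes: &[u8]) -> Result<LakeSoulIOConfig, prost::DecodeError> {
    // Decode the serialized Substrait plan, mirroring the decode in the parser test above.
    let plan = Plan::decode(plan_bytes)?;

    // filter_protos is consumed on the read path (convert_filter / prune_filter_and_execute).
    Ok(LakeSoulIOConfigBuilder::new()
        .with_files(vec!["/path/to/part-00000.snappy.parquet".to_string()]) // placeholder path
        .with_thread_num(1)
        .with_batch_size(256)
        .with_filter_proto(plan)
        .build())
}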