[NativeIO] Doris filter support (#507)
* support doris filter pushdown

Signed-off-by: zenghua <[email protected]>

* add cdcColumnMergeOnReadFilter

Signed-off-by: zenghua <[email protected]>

* fix ci

Signed-off-by: zenghua <[email protected]>

* fix ci

Signed-off-by: zenghua <[email protected]>

* fix arrow shaded pattern

Signed-off-by: zenghua <[email protected]>

---------

Signed-off-by: zenghua <[email protected]>
Co-authored-by: zenghua <[email protected]>
Ceng23333 and zenghua authored Jul 17, 2024
1 parent 21265d8 commit a3c9557
Showing 9 changed files with 200 additions and 22 deletions.
@@ -4,9 +4,10 @@

package com.dmetasoul.lakesoul.meta

import com.dmetasoul.lakesoul.meta.entity.DataCommitInfo
import com.dmetasoul.lakesoul.meta.entity.{DataCommitInfo, PartitionInfo}
import org.apache.hadoop.fs.Path

import java.util
import java.util.{Objects, UUID}
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
@@ -59,6 +60,10 @@ object DataOperation {

val dbManager = new DBManager

def getTableDataInfo(partitionList: util.List[PartitionInfo]): Array[DataFileInfo] = {
getTableDataInfo(MetaVersion.convertPartitionInfoScala(partitionList))
}

def getTableDataInfo(tableId: String): Array[DataFileInfo] = {
getTableDataInfo(MetaVersion.getAllPartitionInfoScala(tableId))
}
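The new overload above lets callers hand `getTableDataInfo` an already-pruned `util.List[PartitionInfo]` (for example, the output of a pushed-down partition filter) instead of re-reading every partition by table id. A minimal usage sketch in Java, assuming the Scala object's static forwarder and that `DataFileInfo` lives in the same `com.dmetasoul.lakesoul.meta` package:

```java
import com.dmetasoul.lakesoul.meta.DataFileInfo;
import com.dmetasoul.lakesoul.meta.DataOperation;
import com.dmetasoul.lakesoul.meta.entity.PartitionInfo;

import java.util.List;

public class FilteredScanSketch {
    // prunedPartitions would typically come from a partition filter such as
    // SubstraitUtil.applyPartitionFilters(...) further down in this commit.
    static DataFileInfo[] listFiles(List<PartitionInfo> prunedPartitions) {
        return DataOperation.getTableDataInfo(prunedPartitions);
    }
}
```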
@@ -189,9 +189,9 @@ object MetaVersion {
dbManager.getAllPartitionInfo(table_id)
}

def getAllPartitionInfoScala(table_id: String): Array[PartitionInfoScala] = {
def convertPartitionInfoScala(partitionList: util.List[PartitionInfo]): Array[PartitionInfoScala] = {
val partitionVersionBuffer = new ArrayBuffer[PartitionInfoScala]()
val res_itr = getAllPartitionInfo(table_id).iterator()
val res_itr = partitionList.iterator()
while (res_itr.hasNext) {
val res = res_itr.next()
partitionVersionBuffer += PartitionInfoScala(
@@ -206,6 +206,10 @@
partitionVersionBuffer.toArray
}

def getAllPartitionInfoScala(table_id: String): Array[PartitionInfoScala] = {
convertPartitionInfoScala(getAllPartitionInfo(table_id))
}

def rollbackPartitionInfoByVersion(table_id: String, range_value: String, toVersion: Int): Unit = {
dbManager.rollbackPartitionByVersion(table_id, range_value, toVersion);
}
@@ -262,7 +262,7 @@ public static CatalogBaseTable toFlinkCatalog(TableInfo tableInfo) {

List<String> parKeys = partitionKeys.rangeKeys;
HashMap<String, String> conf = new HashMap<>();
properties.forEach((key, value) -> conf.put(key, (String) value));
properties.forEach((key, value) -> conf.put(key, value.toString()));
if (FlinkUtil.isView(tableInfo)) {
return CatalogView.of(bd.build(), "", properties.getString(VIEW_ORIGINAL_QUERY),
properties.getString(VIEW_EXPANDED_QUERY), conf);
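The one-line change above fixes a latent `ClassCastException`: the table properties are typed `Object`, and a raw `(String)` cast blows up as soon as a value is, say, an `Integer`. A small self-contained sketch of the failure mode and the fix:

```java
import java.util.HashMap;
import java.util.Map;

public class CastVsToString {
    public static void main(String[] args) {
        Map<String, Object> properties = new HashMap<>();
        properties.put("parallelism", 4); // an Integer, not a String

        Map<String, String> conf = new HashMap<>();
        // properties.forEach((k, v) -> conf.put(k, (String) v)); // ClassCastException
        properties.forEach((k, v) -> conf.put(k, v.toString())); // safe for any value
        System.out.println(conf); // {parallelism=4}
    }
}
```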
@@ -20,6 +20,11 @@

public class SubstraitTest extends AbstractTestBase {

public static void main(String[] args) {
TableEnvironment createTableEnv = TestUtils.createTableEnv(BATCH_TYPE);
createTableEnv.executeSql("select * from nation where n_regionkey = 0 or n_nationkey > 14").print();
}

@Test
public void loadSubStrait() throws IOException {
SimpleExtension.ExtensionCollection extensionCollection = SimpleExtension.loadDefaults();
48 changes: 37 additions & 11 deletions native-io/lakesoul-io-java/pom.xml
@@ -83,7 +83,7 @@ SPDX-License-Identifier: Apache-2.0
<artifactId>scala-reflect</artifactId>
<version>${scala.version}</version>
</dependency>


<!-- jnr-ffi deps-->
<dependency>
@@ -607,8 +607,20 @@ SPDX-License-Identifier: Apache-2.0
<exclude>org/apache/arrow/c/jni/JniLoader.class</exclude>
</excludes>
</filter>
<filter>
<!-- Include all Scala-related resources -->
<artifact>**/*scala*</artifact>
<includes>
<include>**/*.class</include>
<!-- Add any other Scala-related resources here, e.g. *.sbt or *.scala files -->
</includes>
</filter>
</filters>
<relocations>
<relocation>
<pattern>scala.</pattern>
<shadedPattern>com.lakesoul.shaded.scala.</shadedPattern>
</relocation>
<relocation>
<pattern>org.apache.spark.sql</pattern>
<shadedPattern>com.lakesoul.shaded.org.apache.spark.sql</shadedPattern>
@@ -626,24 +638,38 @@ SPDX-License-Identifier: Apache-2.0
<shadedPattern>com.lakesoul.shaded.com.google.flatbuffers</shadedPattern>
</relocation>
<relocation>
<pattern>org.apache.arrow.flatbuf</pattern>
<shadedPattern>io.glutenproject.shaded.org.apache.arrow.flatbuf</shadedPattern>
<pattern>org.apache.arrow</pattern>
<shadedPattern>com.lakesoul.shaded.org.apache.arrow</shadedPattern>
<excludes>
<exclude>org.apache.arrow.c.jni.*</exclude>
<exclude>org.apache.arrow.c.Data*</exclude>
<exclude>org.apache.arrow.c.Format*</exclude>
<exclude>org.apache.arrow.c.Flags*</exclude>
<exclude>org.apache.arrow.c.Metadata*</exclude>
<exclude>org.apache.arrow.c.NativeUtil*</exclude>
<exclude>org.apache.arrow.c.ReferenceCountedArrowArray*</exclude>
<exclude>org.apache.arrow.c.SchemaImporter*</exclude>
<exclude>org.apache.arrow.c.ArrowArrayStreamReader*</exclude>
<exclude>org.apache.arrow.c.ArrayStreamExporter*</exclude>
<exclude>org.apache.arrow.c.ArrayExporter*</exclude>
<exclude>org.apache.arrow.c.SchemaExporter*</exclude>
</excludes>
</relocation>
<relocation>
<pattern>org.apache.arrow.memory</pattern>
<shadedPattern>io.glutenproject.shaded.org.apache.arrow.memory</shadedPattern>
<pattern>com.google.protobuf</pattern>
<shadedPattern>com.lakesoul.shaded.com.google.protobuf</shadedPattern>
</relocation>
<relocation>
<pattern>org.apache.arrow.util</pattern>
<shadedPattern>io.glutenproject.shaded.org.apache.arrow.util</shadedPattern>
<pattern>com.fasterxml.jackson</pattern>
<shadedPattern>com.lakesoul.shaded.com.fasterxml.jackson</shadedPattern>
</relocation>
<relocation>
<pattern>org.apache.arrow.vector</pattern>
<shadedPattern>io.glutenproject.shaded.org.apache.arrow.vector</shadedPattern>
<pattern>org.yaml.snakeyaml</pattern>
<shadedPattern>com.lakesoul.shaded.org.yaml.snakeyaml</shadedPattern>
</relocation>
<relocation>
<pattern>com.google.protobuf</pattern>
<shadedPattern>com.lakesoul.shaded.com.google.protobuf</shadedPattern>
<pattern>org.antlr</pattern>
<shadedPattern>com.lakesoul.shaded.org.antlr</shadedPattern>
</relocation>
</relocations>
</configuration>
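One way to spot-check the relocation rules above after building the shaded jar (an illustrative sketch, not part of the build): Arrow classes should resolve under `com.lakesoul.shaded`, while the JNI bridge classes excluded from relocation keep their original coordinates.

```java
public class ShadeCheck {
    public static void main(String[] args) throws ClassNotFoundException {
        // Relocated by the org.apache.arrow rule above.
        Class.forName("com.lakesoul.shaded.org.apache.arrow.vector.VectorSchemaRoot");
        // Deliberately excluded from relocation, so it must keep its package.
        Class.forName("org.apache.arrow.c.jni.JniLoader");
        System.out.println("relocation looks consistent");
    }
}
```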
@@ -1,5 +1,6 @@
package com.dmetasoul.lakesoul.lakesoul.io.substrait;

import com.dmetasoul.lakesoul.lakesoul.io.DateTimeUtils;
import com.dmetasoul.lakesoul.lakesoul.io.NativeIOBase;
import com.dmetasoul.lakesoul.lakesoul.io.jnr.JnrLoader;
import com.dmetasoul.lakesoul.lakesoul.io.jnr.LibLakeSoulIO;
@@ -12,7 +13,10 @@
import io.substrait.expression.Expression;

import io.substrait.expression.ExpressionCreator;
import io.substrait.expression.FieldReference;
import io.substrait.expression.ImmutableMapKey;
import io.substrait.expression.proto.ExpressionProtoConverter;
import io.substrait.extension.DefaultExtensionCatalog;
import io.substrait.extension.SimpleExtension;
import io.substrait.plan.Plan;
import io.substrait.plan.PlanProtoConverter;
@@ -24,16 +28,21 @@
import org.apache.arrow.c.ArrowSchema;
import org.apache.arrow.c.CDataDictionaryProvider;
import org.apache.arrow.c.Data;
import org.apache.arrow.util.Preconditions;
import org.apache.arrow.vector.types.IntervalUnit;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.spark.sql.catalyst.util.DateTimeUtils$;

import java.io.IOException;
import java.math.BigDecimal;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.sql.Timestamp;
import java.time.Instant;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
@@ -42,13 +51,19 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static io.substrait.extension.DefaultExtensionCatalog.FUNCTIONS_BOOLEAN;

public class SubstraitUtil {
public static final SimpleExtension.ExtensionCollection EXTENSIONS;
public static final SubstraitBuilder BUILDER;

public static final String CompNamespace = "/functions_comparison.yaml";
public static final String BooleanNamespace = "/functions_boolean.yaml";

public static final Expression CONST_TRUE = ExpressionCreator.bool(false, true);

public static final Expression CONST_FALSE = ExpressionCreator.bool(false, false);

private static final LibLakeSoulIO LIB;

private static final Pointer BUFFER1;
@@ -74,15 +89,23 @@ public class SubstraitUtil {
}

public static Expression and(Expression left, Expression right) {
SimpleExtension.FunctionAnchor fa = SimpleExtension.FunctionAnchor.of(BooleanNamespace, "and:bool");
return ExpressionCreator.scalarFunction(EXTENSIONS.getScalarFunction(fa), TypeCreator.NULLABLE.BOOLEAN, left, right);
return makeBinary(left, right, FUNCTIONS_BOOLEAN, "and:bool", TypeCreator.NULLABLE.BOOLEAN);
}

public static Expression or(Expression left, Expression right) {
return makeBinary(left, right, FUNCTIONS_BOOLEAN, "or:bool", TypeCreator.NULLABLE.BOOLEAN);
}

public static Expression not(Expression expression) {
SimpleExtension.FunctionAnchor fa = SimpleExtension.FunctionAnchor.of(BooleanNamespace, "not:bool");
return ExpressionCreator.scalarFunction(EXTENSIONS.getScalarFunction(fa), TypeCreator.NULLABLE.BOOLEAN, expression);
}

public static Expression makeBinary(Expression left, Expression right, String namespace, String funcKey, Type type) {
SimpleExtension.FunctionAnchor fa = SimpleExtension.FunctionAnchor.of(namespace, funcKey);
return ExpressionCreator.scalarFunction(EXTENSIONS.getScalarFunction(fa), type, left, right);
}

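With `or` and the generic `makeBinary` in place, arbitrary boolean trees can be composed from pushed-down conjuncts. A sketch (the three inputs are assumed to be already-built Substrait expressions, e.g. converted Doris predicates):

```java
import com.dmetasoul.lakesoul.lakesoul.io.substrait.SubstraitUtil;
import io.substrait.expression.Expression;

public class FilterComposition {
    // NOT(a) AND (b OR c), built from the helpers above.
    static Expression compose(Expression a, Expression b, Expression c) {
        return SubstraitUtil.and(SubstraitUtil.not(a), SubstraitUtil.or(b, c));
    }
}
```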
public static io.substrait.proto.Plan substraitExprToProto(Expression e, String tableName) {
return planToProto(exprToFilter(e, tableName));
}
@@ -120,6 +143,14 @@ public static io.substrait.proto.Plan planToProto(Plan plan) {
return new PlanProtoConverter().toProto(plan);
}

public static String encodeBase64String(io.substrait.proto.Plan plan) {
return Base64.getEncoder().encodeToString(plan.toByteArray());
}

public static io.substrait.proto.Plan decodeBase64String(String base64) throws InvalidProtocolBufferException {
return io.substrait.proto.Plan.parseFrom(Base64.getDecoder().decode(base64));
}

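The Base64 helpers give the filter plan a transport-friendly form, e.g. for passing the serialized Substrait plan to a scan node as a plain string property; a round-trip sketch:

```java
import com.dmetasoul.lakesoul.lakesoul.io.substrait.SubstraitUtil;
import com.google.protobuf.InvalidProtocolBufferException;

public class PlanTransport {
    static io.substrait.proto.Plan roundTrip(io.substrait.proto.Plan plan)
            throws InvalidProtocolBufferException {
        String encoded = SubstraitUtil.encodeBase64String(plan); // ship as a string
        return SubstraitUtil.decodeBase64String(encoded);        // decode on the other side
    }
}
```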
public static List<PartitionInfo> applyPartitionFilters(List<PartitionInfo> allPartitionInfo, Schema schema, io.substrait.proto.Plan partitionFilter) {
if (allPartitionInfo.isEmpty()) {
return Collections.emptyList();
@@ -199,6 +230,16 @@ public static List<PartitionInfo> applyPartitionFilters(List<PartitionInfo> allP
return resultPartitionInfo;
}


public static FieldReference arrowFieldToSubstraitField(Field field) {
return FieldReference
.builder()
.type(arrowFieldToSubstraitType(field)).addSegments(
ImmutableMapKey.of(ExpressionCreator.string(true, field.getName()))
)
.build();
}

public static Type arrowFieldToSubstraitType(Field field) {
Type type = null;
if (field.getType() instanceof ArrowType.Struct) {
@@ -351,5 +392,76 @@ public Type visit(ArrowType.Duration duration) {
return null;
}
}

public static Expression anyToSubstraitLiteral(Type type, Object any) throws IOException {
if (type instanceof Type.Date) {
if (any instanceof Integer) {
return ExpressionCreator.date(false, (Integer) any);
} else if (any instanceof Date || any instanceof LocalDate) {
return ExpressionCreator.date(false, DateTimeUtils$.MODULE$.anyToDays(any));
}
}
if (type instanceof Type.Timestamp) {
if (any instanceof Long) {
return ExpressionCreator.timestamp(false, (Long) any);
} else if (any instanceof LocalDateTime || any instanceof Timestamp || any instanceof Instant) {
return ExpressionCreator.timestamp(false, DateTimeUtils.toMicros(any));
}
}
if (type instanceof Type.TimestampTZ) {
if (any instanceof Long) {
return ExpressionCreator.timestampTZ(false, (Long) any);
} else if (any instanceof LocalDateTime || any instanceof Timestamp || any instanceof Instant) {
return ExpressionCreator.timestampTZ(false, DateTimeUtils.toMicros(any));
}
}
if (type instanceof Type.Str || any instanceof String) {
return ExpressionCreator.string(false, (String) any);
}
if (type instanceof Type.Bool || any instanceof Boolean) {
return ExpressionCreator.bool(false, (Boolean) any);
}
if (type instanceof Type.Binary || any instanceof byte[]) {
return ExpressionCreator.binary(false, (byte[]) any);
}

if (type instanceof Type.I8 || any instanceof Byte) {
return ExpressionCreator.i8(false, Byte.parseByte(any.toString()));
}
if (type instanceof Type.I16 || any instanceof Short) {
return ExpressionCreator.i16(false, Short.parseShort(any.toString()));
}
if (type instanceof Type.I32) {
return ExpressionCreator.i32(false, Integer.parseInt(any.toString()));
}
if (type instanceof Type.I64) {
return ExpressionCreator.i64(false, Long.parseLong(any.toString()));
}
if (type instanceof Type.FP32 || any instanceof Float) {
return ExpressionCreator.fp32(false, (Float) any);
}
if (type instanceof Type.FP64 || any instanceof Double) {
return ExpressionCreator.fp64(false, (Double) any);
}
if (type instanceof Type.Decimal || any instanceof BigDecimal) {
int precision = 10;
int scale = 0;
if (type != null) {
precision = ((Type.Decimal) type).precision();
scale = ((Type.Decimal) type).scale();
}
return ExpressionCreator.decimal(false, (BigDecimal) any, precision, scale);
}


throw new IOException("Fail convert to SubstraitLiteral for " + any.toString());
}

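A couple of conversions through `anyToSubstraitLiteral` (a usage sketch; the `Type` instances come from substrait-java's `TypeCreator`, and anything the method cannot map raises `IOException` rather than producing a wrong literal):

```java
import com.dmetasoul.lakesoul.lakesoul.io.substrait.SubstraitUtil;
import io.substrait.expression.Expression;
import io.substrait.type.TypeCreator;

import java.io.IOException;
import java.time.LocalDate;

public class LiteralSketch {
    static Expression dateLiteral() throws IOException {
        // LocalDate goes through the days-since-epoch conversion above.
        return SubstraitUtil.anyToSubstraitLiteral(TypeCreator.NULLABLE.DATE,
                LocalDate.of(2024, 7, 17));
    }

    static Expression intLiteral() throws IOException {
        return SubstraitUtil.anyToSubstraitLiteral(TypeCreator.NULLABLE.I32, 14);
    }
}
```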
public static Expression cdcColumnMergeOnReadFilter(Field field) {
Preconditions.checkArgument(field.getType() instanceof ArrowType.Utf8);
FieldReference fieldReference = arrowFieldToSubstraitField(field);
Expression literal = ExpressionCreator.string(false, "delete");
return makeBinary(fieldReference, literal, DefaultExtensionCatalog.FUNCTIONS_COMPARISON, "not_equal:any_any", TypeCreator.REQUIRED.STRING);
}
}

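`cdcColumnMergeOnReadFilter` builds the predicate `cdc_column != 'delete'`, so merge-on-read scans drop rows whose CDC op marks them deleted. A usage sketch (the column name `rowKinds` is an assumption for illustration; per the precondition, the field must be Utf8):

```java
import com.dmetasoul.lakesoul.lakesoul.io.substrait.SubstraitUtil;
import io.substrait.expression.Expression;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;

public class CdcFilterSketch {
    static Expression withCdcFilter(Expression userFilter) {
        // Hypothetical CDC ops column; must be a Utf8 field.
        Field cdcColumn = Field.nullable("rowKinds", ArrowType.Utf8.INSTANCE);
        Expression keepNonDeleted = SubstraitUtil.cdcColumnMergeOnReadFilter(cdcColumn);
        // AND it with the user's pushed-down predicate.
        return SubstraitUtil.and(keepNonDeleted, userFilter);
    }
}
```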
3 changes: 2 additions & 1 deletion rust/lakesoul-datafusion/src/lakesoul_table/mod.rs
@@ -94,10 +94,11 @@ impl LakeSoulTable {
.build()?;
let dataframe = DataFrame::new(sess_ctx.state(), logical_plan);

let _results = dataframe
let results = dataframe
// .explain(true, false)?
.collect()
.await?;
// dbg!(&results);

Ok(())
}
2 changes: 1 addition & 1 deletion rust/lakesoul-datafusion/src/test/benchmarks/tpch/mod.rs
@@ -56,7 +56,7 @@ pub fn get_tbl_tpch_table_range_partitions(table: &str) -> Vec<String> {
}
}

/// The `.tbl` file contains a trailing column
/// The `.tbl` file contains a tailing column
pub fn get_tbl_tpch_table_schema(table: &str) -> Schema {
let mut schema = SchemaBuilder::from(get_tpch_table_schema(table).fields);
schema.push(Field::new("__placeholder", DataType::Utf8, true));
25 changes: 25 additions & 0 deletions rust/lakesoul-datafusion/src/test/integration_tests.rs
@@ -89,4 +89,29 @@ mod integration_tests {

Ok(())
}

use lakesoul_io::lakesoul_reader::LakeSoulReader;

#[tokio::test]
async fn debug() -> Result<()> {
let config =
LakeSoulIOConfigBuilder::new()
.with_schema(Arc::new(get_tbl_tpch_table_schema("nation")))
.with_files(vec!["/Users/ceng/Documents/GitHub/LakeSoul/rust/lakesoul-datafusion/default/nation/n_regionkey=0/part-4vqnoXvFFTInJqDV_0000.parquet".to_string()])
.with_default_column_value("n_regionkey".to_string(), "0".to_string())
.build();
let mut reader = LakeSoulReader::new(config)?;
reader.start().await?;

let mut row_cnt = 0;
while let Some(rb) = reader.next_rb().await {
let rb = rb.unwrap();
dbg!(&rb);
let num_rows = rb.num_rows();
row_cnt += num_rows;
}
dbg!(row_cnt);

Ok(())
}
}
