clean code
Signed-off-by: mag1c1an1 <[email protected]>
mag1c1an1 committed Mar 29, 2024
1 parent 1bb2c98 commit 1e49766
Showing 12 changed files with 140 additions and 456 deletions.
@@ -234,7 +234,6 @@ public DynamicTableSource copy() {
                 this.optionParams);
         lsts.projectedFields = this.projectedFields;
         lsts.remainingPartitions = this.remainingPartitions;
-        lsts._filterPredicate = this._filterPredicate;
         lsts.filter = this.filter;
         lsts.modificationContext = this.modificationContext;
         return lsts;

@@ -22,7 +22,6 @@
 import org.apache.flink.table.types.logical.VarCharType;
 import org.apache.flink.table.utils.PartitionPathUtils;
 import org.apache.flink.types.RowKind;
-import org.apache.parquet.filter2.predicate.FilterPredicate;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -74,7 +73,6 @@ public class LakeSoulOneSplitRecordsReader implements RecordsWithSplitIds<RowData>
     // arrow batch -> row, with requested schema
     private ArrowReader curArrowReaderRequestedSchema;
 
-    private final FilterPredicate _filterPredicate;
     private final Plan filter;
 
     public LakeSoulOneSplitRecordsReader(Configuration conf,
@@ -84,7 +82,6 @@ public LakeSoulOneSplitRecordsReader(Configuration conf,
                                          List<String> pkColumns,
                                          boolean isStreaming,
                                          String cdcColumn,
-                                         FilterPredicate _filterPredicate,
                                          Plan filter)
             throws Exception {
         this.split = split;
@@ -97,7 +94,6 @@ public LakeSoulOneSplitRecordsReader(Configuration conf,
         this.isStreaming = isStreaming;
         this.cdcColumn = cdcColumn;
         this.finishedSplit = Collections.singleton(splitId);
-        this._filterPredicate = _filterPredicate;
         this.filter = filter;
         initializeReader();
         recoverFromSkipRecord();
@@ -136,10 +132,6 @@ private void initializeReader() throws IOException {
             reader.addFilterProto(this.filter);
         }
 
-        if (_filterPredicate !=null) {
-            reader.addFilter(_filterPredicate.toString());
-        }
-
         LOG.info("Initializing reader for split {}, pk={}, partitions={}," +
                         " non partition cols={}, cdc column={}, filter={}",
                 split,

@@ -55,7 +55,6 @@ public LakeSoulSource(TableId tableId,
                           List<String> pkColumns,
                           Map<String, String> optionParams,
                           @Nullable List<Map<String, String>> remainingPartitions,
-                          @Nullable FilterPredicate _filterPredicate,
                           @Nullable Plan filter) {
         this.tableId = tableId;
         this.rowType = rowType;
@@ -64,7 +63,6 @@
         this.pkColumns = pkColumns;
         this.optionParams = optionParams;
         this.remainingPartitions = remainingPartitions;
-        this._filterPredicate = _filterPredicate;
         this.filter = filter;
 
     }
@@ -89,7 +87,6 @@ public SourceReader<RowData, LakeSoulSplit> createReader(SourceReaderContext readerContext) {
                         this.pkColumns,
                         this.isStreaming,
                         this.optionParams.getOrDefault(LakeSoulSinkOptions.CDC_CHANGE_COLUMN, ""),
-                        this._filterPredicate,
                         this.filter),
                 new LakeSoulRecordEmitter(),
                 readerContext.getConfiguration(),

@@ -12,7 +12,6 @@
 import org.apache.flink.connector.base.source.reader.splitreader.SplitsChange;
 import org.apache.flink.table.data.RowData;
 import org.apache.flink.table.types.logical.RowType;
-import org.apache.parquet.filter2.predicate.FilterPredicate;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -40,7 +39,6 @@ public class LakeSoulSplitReader implements SplitReader<RowData, LakeSoulSplit>
 
     String cdcColumn;
 
-    FilterPredicate _filterPredicate;
     Plan filter;
 
     private LakeSoulOneSplitRecordsReader lastSplitReader;
@@ -51,7 +49,6 @@ public LakeSoulSplitReader(Configuration conf,
                                List<String> pkColumns,
                                boolean isStreaming,
                                String cdcColumn,
-                               FilterPredicate _filterPredicate,
                                Plan filter) {
         this.conf = conf;
         this.splits = new ArrayDeque<>();
@@ -60,7 +57,6 @@
         this.pkColumns = pkColumns;
         this.isStreaming = isStreaming;
         this.cdcColumn = cdcColumn;
-        this._filterPredicate = _filterPredicate;
         this.filter = filter;
     }
 
@@ -76,7 +72,6 @@ public RecordsWithSplitIds<RowData> fetch() throws IOException {
                     this.pkColumns,
                     this.isStreaming,
                     this.cdcColumn,
-                    this._filterPredicate,
                     this.filter
             );
         return lastSplitReader;

@@ -24,14 +24,12 @@ public class SubstraitFlinkUtil {
     public static Tuple2<SupportsFilterPushDown.Result, io.substrait.proto.Plan> flinkExprToSubStraitPlan(
             List<ResolvedExpression> exprs,
             List<ResolvedExpression> remaining,
-            String tableName,
-            String tableSchema
-    ) throws IOException {
+            String tableName
+    ) {
         List<ResolvedExpression> accepted = new ArrayList<>();
-        Schema arrowSchema = Schema.fromJSON(tableSchema);
         Expression last = null;
         for (ResolvedExpression expr : exprs) {
-            Expression e = doTransform(expr, arrowSchema);
+            Expression e = doTransform(expr);
             if (e == null) {
                 remaining.add(expr);
             } else {
@@ -44,12 +42,12 @@ public static Tuple2<SupportsFilterPushDown.Result, io.substrait.proto.Plan> flinkExprToSubStraitPlan(
                 }
             }
         }
-        Plan filter = exprToFilter(last, tableName, arrowSchema);
+        Plan filter = exprToFilter(last, tableName);
         return Tuple2.of(SupportsFilterPushDown.Result.of(accepted, remaining), planToProto(filter));
     }
 
-    public static Expression doTransform(ResolvedExpression flinkExpression, Schema arrowSchema) {
-        SubstraitVisitor substraitVisitor = new SubstraitVisitor(arrowSchema);
+    public static Expression doTransform(ResolvedExpression flinkExpression) {
+        SubstraitVisitor substraitVisitor = new SubstraitVisitor();
         return flinkExpression.accept(substraitVisitor);
     }
 

@@ -8,44 +8,33 @@
 import io.substrait.extension.SimpleExtension;
 import io.substrait.type.Type;
 import io.substrait.type.TypeCreator;
-import org.apache.arrow.vector.types.pojo.Field;
-import org.apache.arrow.vector.types.pojo.Schema;
 import org.apache.flink.table.expressions.*;
 import org.apache.flink.table.expressions.ExpressionVisitor;
 import org.apache.flink.table.functions.BuiltInFunctionDefinitions;
 import org.apache.flink.table.functions.FunctionDefinition;
 import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.DecimalType;
 import org.apache.flink.table.types.logical.LogicalType;
 import org.apache.flink.table.types.logical.LogicalTypeRoot;
 import org.apache.spark.sql.catalyst.util.DateTimeUtils$;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.sql.Timestamp;
 import java.time.Instant;
+import java.math.BigDecimal;
 import java.time.LocalDate;
 import java.time.LocalDateTime;
 import java.util.*;
-import java.util.function.BiFunction;
+import java.util.function.Function;
 
 
 /**
  * return null means cannot convert
  */
 public class SubstraitVisitor implements ExpressionVisitor<Expression> {
 
-    public SubstraitVisitor(Schema arrowSchema) {
-        this.arrowSchema = arrowSchema;
-    }
-
 
     private static final Logger LOG = LoggerFactory.getLogger(SubstraitVisitor.class);
 
-    private Schema arrowSchema;
-
     @Override
     public Expression visit(CallExpression call) {
-        CallExprVisitor callVisitor = new CallExprVisitor(this.arrowSchema);
+        CallExprVisitor callVisitor = new CallExprVisitor();
         return callVisitor.visit(call);
     }
 
@@ -56,7 +45,7 @@ public Expression visit(ValueLiteralExpression valueLiteral) {
 
     @Override
     public Expression visit(FieldReferenceExpression fieldReference) {
-        return new FieldRefVisitor(this.arrowSchema).visit(fieldReference);
+        return new FieldRefVisitor().visit(fieldReference);
     }
 
     @Override
@@ -164,6 +153,14 @@ public Expression.Literal visit(ValueLiteralExpression valueLiteral) {
                 }
                 return ExpressionCreator.fp64(nullable, d);
             }
+            case DECIMAL: {
+                BigDecimal bigDecimal = new BigDecimal(0);
+                DecimalType dt = (DecimalType) logicalType;
+                if (value != null) {
+                    bigDecimal = (BigDecimal) value;
+                }
+                return ExpressionCreator.decimal(nullable, bigDecimal, dt.getPrecision(), dt.getScale());
+            }
             case DATE: {
                 int days = 0;
                 if (value != null) {
@@ -217,12 +214,6 @@ protected Expression.Literal defaultMethod(org.apache.flink.table.expressions.Expression expression) {
 }
 
 class FieldRefVisitor extends ExpressionDefaultVisitor<FieldReference> {
-    public FieldRefVisitor(Schema arrow_schema) {
-        this.arrow_schema = arrow_schema;
-    }
-
-    private Schema arrow_schema;
-
     private static final Logger LOG = LoggerFactory.getLogger(FieldRefVisitor.class);
 
     public FieldReference visit(FieldReferenceExpression fieldReference) {
@@ -232,29 +223,16 @@ public FieldReference visit(FieldReferenceExpression fieldReference) {
             fieldReference = (FieldReferenceExpression) fieldReference.getChildren().get(0);
         }
         LogicalType logicalType = fieldReference.getOutputDataType().getLogicalType();
-        LogicalTypeRoot typeRoot = logicalType.getTypeRoot();
-        Type type = mapType(typeRoot, logicalType.isNullable());
+        Type type = mapType(logicalType);
         if (type == null) {
             // not supported
             return null;
         }
         String name = fieldReference.getName();
-        int idx = 0;
-        List<Field> fields = arrow_schema.getFields();
-        // find idx
-        for (int i = 0; i < fields.size(); i++) {
-            if (fields.get(i).getName().equals(name)) {
-                idx = i;
-                break;
-            }
-        }
-
         return FieldReference.builder()
                 .type(Objects.requireNonNull(type))
                 .addSegments(
-                        ImmutableStructField.builder()
-                                .offset(idx)
-                                .build()
+                        ImmutableMapKey.of(ExpressionCreator.string(true, name))
                 )
                 .build();
     }
@@ -264,7 +242,9 @@ protected FieldReference defaultMethod(org.apache.flink.table.expressions.Expression expression) {
         return null;
     }
 
-    public static Type mapType(LogicalTypeRoot typeRoot, Boolean nullable) {
+    public static Type mapType(LogicalType logicalType) {
+        LogicalTypeRoot typeRoot = logicalType.getTypeRoot();
+        boolean nullable = logicalType.isNullable();
         TypeCreator R = TypeCreator.of(nullable);
         switch (typeRoot) {
             case CHAR:
@@ -295,6 +275,10 @@ public static Type mapType(LogicalTypeRoot typeRoot, Boolean nullable) {
             case DOUBLE: {
                 return R.FP64;
             }
+            case DECIMAL: {
+                DecimalType dt = (DecimalType) logicalType;
+                return R.decimal(dt.getPrecision(), dt.getScale());
+            }
             case DATE: {
                 return R.DATE;
             }
@@ -315,28 +299,22 @@ public static Type mapType(LogicalTypeRoot typeRoot, Boolean nullable) {
 }
 
 class CallExprVisitor extends ExpressionDefaultVisitor<Expression> {
-
-    public CallExprVisitor(Schema arrowSchema) {
-        this.arrowSchema = arrowSchema;
-    }
-
-    private Schema arrowSchema;
     private static final Logger LOG = LoggerFactory.getLogger(CallExprVisitor.class);
-    private static final ImmutableMap<FunctionDefinition, BiFunction<CallExpression, Schema, Expression>>
+    private static final ImmutableMap<FunctionDefinition, Function<CallExpression, Expression>>
             FILTERS =
             new ImmutableMap.Builder<
-                    FunctionDefinition, BiFunction<CallExpression, Schema, Expression>>()
-                    .put(BuiltInFunctionDefinitions.IS_NULL, (call, schema) -> makeUnaryFunction(call, schema, "is_null:any", SubstraitUtil.CompNamespace))
-                    .put(BuiltInFunctionDefinitions.IS_NOT_NULL, (call, schema) -> makeUnaryFunction(call, schema, "is_not_null:any", SubstraitUtil.CompNamespace))
-                    .put(BuiltInFunctionDefinitions.NOT, (call, schema) -> makeUnaryFunction(call, schema, "not:bool", SubstraitUtil.BooleanNamespace))
-                    .put(BuiltInFunctionDefinitions.OR, (call, schema) -> makeBinaryFunction(call, schema, "or:bool", SubstraitUtil.BooleanNamespace))
-                    .put(BuiltInFunctionDefinitions.AND, (call, schema) -> makeBinaryFunction(call, schema, "and:bool", SubstraitUtil.BooleanNamespace))
-                    .put(BuiltInFunctionDefinitions.EQUALS, (call, schema) -> makeBinaryFunction(call, schema, "equal:any_any", SubstraitUtil.CompNamespace))
-                    .put(BuiltInFunctionDefinitions.NOT_EQUALS, (call, schema) -> makeBinaryFunction(call, schema, "not_equal:any_any", SubstraitUtil.CompNamespace))
-                    .put(BuiltInFunctionDefinitions.GREATER_THAN, (call, schema) -> makeBinaryFunction(call, schema, "gt:any_any", SubstraitUtil.CompNamespace))
-                    .put(BuiltInFunctionDefinitions.GREATER_THAN_OR_EQUAL, (call, schema) -> makeBinaryFunction(call, schema, "gte:any_any", SubstraitUtil.CompNamespace))
-                    .put(BuiltInFunctionDefinitions.LESS_THAN, (call, schema) -> makeBinaryFunction(call, schema, "lt:any_any", SubstraitUtil.CompNamespace))
-                    .put(BuiltInFunctionDefinitions.LESS_THAN_OR_EQUAL, (call, schema) -> makeBinaryFunction(call, schema, "lte:any_any", SubstraitUtil.CompNamespace))
+                    FunctionDefinition, Function<CallExpression, Expression>>()
+                    .put(BuiltInFunctionDefinitions.IS_NULL, call -> makeUnaryFunction(call, "is_null:any", SubstraitUtil.CompNamespace))
+                    .put(BuiltInFunctionDefinitions.IS_NOT_NULL, call -> makeUnaryFunction(call, "is_not_null:any", SubstraitUtil.CompNamespace))
+                    .put(BuiltInFunctionDefinitions.NOT, call -> makeUnaryFunction(call, "not:bool", SubstraitUtil.BooleanNamespace))
+                    .put(BuiltInFunctionDefinitions.OR, call -> makeBinaryFunction(call, "or:bool", SubstraitUtil.BooleanNamespace))
+                    .put(BuiltInFunctionDefinitions.AND, call -> makeBinaryFunction(call, "and:bool", SubstraitUtil.BooleanNamespace))
+                    .put(BuiltInFunctionDefinitions.EQUALS, call -> makeBinaryFunction(call, "equal:any_any", SubstraitUtil.CompNamespace))
+                    .put(BuiltInFunctionDefinitions.NOT_EQUALS, call -> makeBinaryFunction(call, "not_equal:any_any", SubstraitUtil.CompNamespace))
+                    .put(BuiltInFunctionDefinitions.GREATER_THAN, call -> makeBinaryFunction(call, "gt:any_any", SubstraitUtil.CompNamespace))
+                    .put(BuiltInFunctionDefinitions.GREATER_THAN_OR_EQUAL, call -> makeBinaryFunction(call, "gte:any_any", SubstraitUtil.CompNamespace))
+                    .put(BuiltInFunctionDefinitions.LESS_THAN, call -> makeBinaryFunction(call, "lt:any_any", SubstraitUtil.CompNamespace))
+                    .put(BuiltInFunctionDefinitions.LESS_THAN_OR_EQUAL, call -> makeBinaryFunction(call, "lte:any_any", SubstraitUtil.CompNamespace))
                     .build();
 
     @Override
@@ -348,13 +326,13 @@ public Expression visit(CallExpression call) {
                     call);
             return null;
         }
-        return FILTERS.get(call.getFunctionDefinition()).apply(call, this.arrowSchema);
+        return FILTERS.get(call.getFunctionDefinition()).apply(call);
     }
 
-    static Expression makeBinaryFunction(CallExpression call, Schema arrow_schema, String funcKey, String namespace) {
+    static Expression makeBinaryFunction(CallExpression call, String funcKey, String namespace) {
         List<org.apache.flink.table.expressions.Expression> children = call.getChildren();
         assert children.size() == 2;
-        SubstraitVisitor visitor = new SubstraitVisitor(arrow_schema);
+        SubstraitVisitor visitor = new SubstraitVisitor();
         Expression left = children.get(0).accept(visitor);
         Expression right = children.get(1).accept(visitor);
         if (left == null || right == null) {
@@ -367,10 +345,10 @@ static Expression makeBinaryFunction(CallExpression call, Schema arrow_schema, String funcKey, String namespace) {
         return ExpressionCreator.scalarFunction(func, TypeCreator.NULLABLE.BOOLEAN, args);
     }
 
-    static Expression makeUnaryFunction(CallExpression call, Schema arrow_schema, String funcKey, String namespace) {
+    static Expression makeUnaryFunction(CallExpression call, String funcKey, String namespace) {
         List<org.apache.flink.table.expressions.Expression> children = call.getChildren();
         assert children.size() == 1;
-        SubstraitVisitor visitor = new SubstraitVisitor(arrow_schema);
+        SubstraitVisitor visitor = new SubstraitVisitor();
         Expression child = children.get(0).accept(visitor);
         if (child == null) {
             return null;