Skip to content

Commit

Permalink
[BugFix] RangePredicateExtractor discard predicates mistakenly (backp…
Browse files Browse the repository at this point in the history
…ort #50854)

Signed-off-by: satanson <[email protected]>
  • Loading branch information
satanson committed Sep 10, 2024
1 parent 5c2763f commit 8d5581f
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ private ScalarOperator rewrite(ScalarOperator predicate, boolean onlyExtractColu

Set<ScalarOperator> conjuncts = Sets.newLinkedHashSet();
conjuncts.addAll(Utils.extractConjuncts(predicate));
predicate = Utils.compoundAnd(conjuncts);

Map<ScalarOperator, ValueDescriptor> extractMap = extractImpl(predicate);

Expand All @@ -90,7 +91,6 @@ private ScalarOperator rewrite(ScalarOperator predicate, boolean onlyExtractColu
return predicate;
}

predicate = Utils.compoundAnd(Lists.newArrayList(conjuncts));
if (isOnlyOrCompound(predicate)) {
Set<ColumnRefOperator> c = Sets.newHashSet(Utils.extractColumnRef(predicate));
if (c.size() == extractMap.size() &&
Expand All @@ -103,6 +103,17 @@ private ScalarOperator rewrite(ScalarOperator predicate, boolean onlyExtractColu
List<ScalarOperator> cs = Utils.extractConjuncts(predicate);
Set<ColumnRefOperator> cf = new HashSet<>(Utils.extractColumnRef(predicate));

// getSourceCount = cs.size() means that all and components have the same column ref
// and it can be merged into one range predicate. mistakenly, when the predicate is
// date_trunc(YEAR, dt) = '2024-01-01' AND mode = 'Buzz' AND
// date_trunc(YEAR, dt) = '2024-01-01' AND mode = 'Buzz' //duplicate
//
// only mode = 'Buzz' is a extractable range predicate, so its corresponding ValueDescriptor's
// sourceCount = 2(since it occurs twice), and cs(it is also 2 in this example)are number of
// unique column refs of the predicate, the two values are properly equivalent, so it yields
// wrong result.
//
// Components of AND/OR should be deduplicated at first to avoid this issue.
if (extractMap.values().stream().allMatch(valueDescriptor -> valueDescriptor.sourceCount == cs.size())
&& extractMap.size() == cf.size()) {
if (result.size() == conjuncts.size()) {
Expand Down Expand Up @@ -476,4 +487,4 @@ private static boolean isOnlyOrCompound(ScalarOperator predicate) {
return true;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,17 @@

package com.starrocks.sql.optimizer.rewrite;

import com.google.common.base.Preconditions;
import com.google.common.base.Supplier;
import com.google.common.collect.Lists;
import com.starrocks.analysis.BinaryType;
import com.starrocks.catalog.FunctionSet;
import com.starrocks.catalog.Type;
import com.starrocks.sql.optimizer.Utils;
import com.starrocks.sql.optimizer.operator.OperatorType;
import com.starrocks.sql.optimizer.operator.scalar.BetweenPredicateOperator;
import com.starrocks.sql.optimizer.operator.scalar.BinaryPredicateOperator;
import com.starrocks.sql.optimizer.operator.scalar.CallOperator;
import com.starrocks.sql.optimizer.operator.scalar.CastOperator;
import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator;
import com.starrocks.sql.optimizer.operator.scalar.CompoundPredicateOperator;
Expand All @@ -32,8 +37,13 @@
import com.starrocks.sql.optimizer.rewrite.scalar.NormalizePredicateRule;
import com.starrocks.sql.optimizer.rewrite.scalar.ReduceCastRule;
import com.starrocks.sql.optimizer.rewrite.scalar.SimplifiedPredicateRule;
import org.junit.Assert;
import org.junit.Test;

import java.time.LocalDateTime;
import java.util.Arrays;
import java.util.stream.Collectors;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

Expand Down Expand Up @@ -133,4 +143,32 @@ public void testNormalizeIsNull() {
.rewrite(isnotNull, ScalarOperatorRewriter.DEFAULT_REWRITE_SCAN_PREDICATE_RULES);
assertEquals(ConstantOperator.TRUE, rewritten2);
}

@Test
public void testRangeExtract() {
Supplier<ScalarOperator> predicate1Maker = () -> {
ColumnRefOperator col1 = new ColumnRefOperator(1, Type.DATE, "dt", false);
CallOperator call = new CallOperator(FunctionSet.DATE_TRUNC, Type.DATE,
Arrays.asList(ConstantOperator.createVarchar("YEAR"), col1));
return new BinaryPredicateOperator(BinaryType.EQ, call, ConstantOperator.createDate(
LocalDateTime.of(2024, 1, 1, 0, 0, 0)));
};

Supplier<ScalarOperator> predicate2Maker = () -> {
ColumnRefOperator col2 = new ColumnRefOperator(2, Type.VARCHAR, "mode", false);
return new BinaryPredicateOperator(BinaryType.EQ, col2, ConstantOperator.createVarchar("Buzz"));
};

ScalarOperator predicate1 = Utils.compoundAnd(predicate1Maker.get(), predicate2Maker.get());
ScalarOperator predicate2 = Utils.compoundAnd(predicate1Maker.get(), predicate2Maker.get());
ScalarOperator predicates = Utils.compoundAnd(predicate1, predicate2);

ScalarRangePredicateExtractor rangeExtractor = new ScalarRangePredicateExtractor();
ScalarOperator result = rangeExtractor.rewriteOnlyColumn(Utils.compoundAnd(Utils.extractConjuncts(predicates)
.stream().map(rangeExtractor::rewriteOnlyColumn).collect(Collectors.toList())));
Preconditions.checkState(result != null);
String expect = "date_trunc(YEAR, 1: dt) = 2024-01-01 AND 2: mode = Buzz";
String actual = result.toString();
Assert.assertEquals(actual, expect, actual);
}
}
25 changes: 16 additions & 9 deletions fe/fe-core/src/test/java/com/starrocks/sql/plan/SubqueryTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1798,10 +1798,10 @@ public void testCorrelatedPredicateRewrite_1() throws Exception {

String sql = "select v1 from t0 where v1 = 1 or v2 in (select v4 from t1 where v2 = v4 and v5 = 1)";
String plan = getFragmentPlan(sql);
assertContains(plan, "15:AGGREGATE (merge finalize)\n" +
" | group by: 8: v4\n" +
" | \n" +
" 14:EXCHANGE");
assertContains(plan, " |----15:AGGREGATE (merge finalize)\n" +
" | | group by: 8: v4\n" +
" | | \n" +
" | 14:EXCHANGE");
assertContains(plan, "9:HASH JOIN\n" +
" | join op: RIGHT OUTER JOIN (BUCKET_SHUFFLE(S))\n" +
" | colocate: false, reason: \n" +
Expand All @@ -1822,16 +1822,23 @@ public void testCorrelatedPredicateRewrite_2() throws Exception {
"or v2 in (select v4 from t1 where v2 = v4 and v5 = 1)";

String plan = getFragmentPlan(sql);
assertContains(plan, "33:AGGREGATE (merge finalize)\n" +
" | group by: 12: v4\n" +
assertContains(plan, " |----32:AGGREGATE (merge finalize)\n" +
" | | group by: 12: v4\n" +
" | | \n" +
" | 31:EXCHANGE");
assertContains(plan, " 27:Project\n" +
" | <slot 1> : 1: v1\n" +
" | <slot 2> : 2: v2\n" +
" | <slot 7> : 7: expr\n" +
" | <slot 15> : 15: countRows\n" +
" | <slot 16> : 16: countNotNulls\n" +
" | \n" +
" 32:EXCHANGE");
assertContains(plan, "27:HASH JOIN\n" +
" 26:HASH JOIN\n" +
" | join op: RIGHT OUTER JOIN (BUCKET_SHUFFLE(S))\n" +
" | colocate: false, reason: \n" +
" | equal join conjunct: 14: v4 = 2: v2\n" +
" | \n" +
" |----26:EXCHANGE\n" +
" |----25:EXCHANGE\n" +
" | \n" +
" 6:AGGREGATE (merge finalize)\n" +
" | output: count(15: countRows), count(16: countNotNulls)\n" +
Expand Down

0 comments on commit 8d5581f

Please sign in to comment.