Skip to content

Commit

Permalink
[BugFix] RangePredicateExtractor discard predicates mistakenly (backp…
Browse files Browse the repository at this point in the history
…ort #50854) (#50906)

Signed-off-by: satanson <[email protected]>
  • Loading branch information
satanson authored Sep 19, 2024
1 parent 3164f1d commit bcc1c1a
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ private ScalarOperator rewrite(ScalarOperator predicate, boolean onlyExtractColu

Set<ScalarOperator> conjuncts = Sets.newLinkedHashSet();
conjuncts.addAll(Utils.extractConjuncts(predicate));
predicate = Utils.compoundAnd(conjuncts);

Map<ScalarOperator, ValueDescriptor> extractMap = extractImpl(predicate);

Expand All @@ -78,7 +79,6 @@ private ScalarOperator rewrite(ScalarOperator predicate, boolean onlyExtractColu
return predicate;
}

predicate = Utils.compoundAnd(Lists.newArrayList(conjuncts));
if (isOnlyOrCompound(predicate)) {
Set<ColumnRefOperator> c = Sets.newHashSet(Utils.extractColumnRef(predicate));
if (c.size() == extractMap.size() &&
Expand All @@ -91,6 +91,17 @@ private ScalarOperator rewrite(ScalarOperator predicate, boolean onlyExtractColu
List<ScalarOperator> cs = Utils.extractConjuncts(predicate);
Set<ColumnRefOperator> cf = new HashSet<>(Utils.extractColumnRef(predicate));

// getSourceCount = cs.size() means that all and components have the same column ref
// and it can be merged into one range predicate. mistakenly, when the predicate is
// date_trunc(YEAR, dt) = '2024-01-01' AND mode = 'Buzz' AND
// date_trunc(YEAR, dt) = '2024-01-01' AND mode = 'Buzz' //duplicate
//
// only mode = 'Buzz' is a extractable range predicate, so its corresponding ValueDescriptor's
// sourceCount = 2(since it occurs twice), and cs(it is also 2 in this example)are number of
// unique column refs of the predicate, the two values are properly equivalent, so it yields
// wrong result.
//
// Components of AND/OR should be deduplicated at first to avoid this issue.
if (extractMap.values().stream().allMatch(valueDescriptor -> valueDescriptor.sourceCount == cs.size())
&& extractMap.size() == cf.size()) {
return extractExpr;
Expand Down Expand Up @@ -454,4 +465,4 @@ private static boolean isOnlyOrCompound(ScalarOperator predicate) {
return true;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,16 @@

package com.starrocks.sql.optimizer.rewrite;

import com.google.common.base.Preconditions;
import com.google.common.base.Supplier;
import com.google.common.collect.Lists;
import com.starrocks.catalog.FunctionSet;
import com.starrocks.catalog.Type;
import com.starrocks.sql.optimizer.Utils;
import com.starrocks.sql.optimizer.operator.OperatorType;
import com.starrocks.sql.optimizer.operator.scalar.BetweenPredicateOperator;
import com.starrocks.sql.optimizer.operator.scalar.BinaryPredicateOperator;
import com.starrocks.sql.optimizer.operator.scalar.CallOperator;
import com.starrocks.sql.optimizer.operator.scalar.CastOperator;
import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator;
import com.starrocks.sql.optimizer.operator.scalar.CompoundPredicateOperator;
Expand All @@ -18,8 +23,13 @@
import com.starrocks.sql.optimizer.rewrite.scalar.NormalizePredicateRule;
import com.starrocks.sql.optimizer.rewrite.scalar.ReduceCastRule;
import com.starrocks.sql.optimizer.rewrite.scalar.SimplifiedPredicateRule;
import org.junit.Assert;
import org.junit.Test;

import java.time.LocalDateTime;
import java.util.Arrays;
import java.util.stream.Collectors;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

Expand Down Expand Up @@ -119,4 +129,33 @@ public void testNormalizeIsNull() {
.rewrite(isnotNull, ScalarOperatorRewriter.DEFAULT_REWRITE_SCAN_PREDICATE_RULES);
assertEquals(ConstantOperator.TRUE, rewritten2);
}

@Test
public void testRangeExtract() {
Supplier<ScalarOperator> predicate1Maker = () -> {
ColumnRefOperator col1 = new ColumnRefOperator(1, Type.DATE, "dt", false);
CallOperator call = new CallOperator(FunctionSet.DATE_TRUNC, Type.DATE,
Arrays.asList(ConstantOperator.createVarchar("YEAR"), col1));
return new BinaryPredicateOperator(BinaryPredicateOperator.BinaryType.EQ, call,
ConstantOperator.createDate(LocalDateTime.of(2024, 1, 1, 0, 0, 0)));
};

Supplier<ScalarOperator> predicate2Maker = () -> {
ColumnRefOperator col2 = new ColumnRefOperator(2, Type.VARCHAR, "mode", false);
return new BinaryPredicateOperator(BinaryPredicateOperator.BinaryType.EQ, col2,
ConstantOperator.createVarchar("Buzz"));
};

ScalarOperator predicate1 = Utils.compoundAnd(predicate1Maker.get(), predicate2Maker.get());
ScalarOperator predicate2 = Utils.compoundAnd(predicate1Maker.get(), predicate2Maker.get());
ScalarOperator predicates = Utils.compoundAnd(predicate1, predicate2);

ScalarRangePredicateExtractor rangeExtractor = new ScalarRangePredicateExtractor();
ScalarOperator result = rangeExtractor.rewriteOnlyColumn(Utils.compoundAnd(Utils.extractConjuncts(predicates)
.stream().map(rangeExtractor::rewriteOnlyColumn).collect(Collectors.toList())));
Preconditions.checkState(result != null);
String expect = "date_trunc(YEAR, 1: dt) = 2024-01-01 AND 2: mode = Buzz";
String actual = result.toString();
Assert.assertEquals(actual, expect, actual);
}
}
56 changes: 29 additions & 27 deletions fe/fe-core/src/test/java/com/starrocks/sql/plan/SubqueryTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1794,21 +1794,22 @@ public void testNestSubquery() throws Exception {
public void testCorrelatedPredicateRewrite_1() throws Exception {
String sql = "select v1 from t0 where v1 = 1 or v2 in (select v4 from t1 where v2 = v4 and v5 = 1)";
String plan = getFragmentPlan(sql);
System.out.println(plan);
assertContains(plan, "7:AGGREGATE (merge finalize)\n" +
" | group by: 8: v4\n" +
" | \n" +
" 6:EXCHANGE");
assertContains(plan, "13:AGGREGATE (update serialize)\n" +
" | STREAMING\n" +
" | output: count(1), count(9: v4)\n" +
" | group by: 10: v4\n" +
" | \n" +
" 12:Project\n" +
" | <slot 9> : 4: v4\n" +
" | <slot 10> : 4: v4\n" +
assertContains(plan, " |----14:AGGREGATE (merge finalize)\n" +
" | | output: count(11: countRows), count(12: countNotNulls)\n" +
" | | group by: 10: v4\n" +
" | | \n" +
" | 13:EXCHANGE");
assertContains(plan, " 9:HASH JOIN\n" +
" | join op: LEFT OUTER JOIN (BUCKET_SHUFFLE(S))\n" +
" | colocate: false, reason: \n" +
" | equal join conjunct: 2: v2 = 8: v4\n" +
" | \n" +
" 11:EXCHANGE");
" |----8:AGGREGATE (merge finalize)\n" +
" | | group by: 8: v4\n" +
" | | \n" +
" | 7:EXCHANGE\n" +
" | \n" +
" 3:EXCHANGE");
}

@Test
Expand All @@ -1817,20 +1818,21 @@ public void testCorrelatedPredicateRewrite_2() throws Exception {
"or v2 in (select v4 from t1 where v2 = v4 and v5 = 1)";

String plan = getFragmentPlan(sql);
assertContains(plan, "24:AGGREGATE (merge finalize)\n" +
" | group by: 12: v4\n" +
" | \n" +
" 23:EXCHANGE");
assertContains(plan, "30:AGGREGATE (update serialize)\n" +
" | STREAMING\n" +
" | output: count(1), count(13: v4)\n" +
" | group by: 14: v4\n" +
" | \n" +
" 29:Project\n" +
" | <slot 13> : 8: v4\n" +
" | <slot 14> : 8: v4\n" +
assertContains(plan, " |----30:AGGREGATE (merge finalize)\n" +
" | | output: count(15: countRows), count(16: countNotNulls)\n" +
" | | group by: 14: v4\n" +
" | | \n" +
" | 29:EXCHANGE");
assertContains(plan, " 17:HASH JOIN\n" +
" | join op: LEFT OUTER JOIN (BUCKET_SHUFFLE(S))\n" +
" | colocate: false, reason: \n" +
" | equal join conjunct: 1: v1 = 19: v5\n" +
" | \n" +
" 28:EXCHANGE");
" |----16:AGGREGATE (merge finalize)\n" +
" | | output: count(20: countRows), count(21: countNotNulls)\n" +
" | | group by: 19: v5\n" +
" | | \n" +
" | 15:EXCHANGE");
}

@Test
Expand Down

0 comments on commit bcc1c1a

Please sign in to comment.