Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] RangePredicateExtractor discard predicates mistakenly (backport #50854) #50906

Merged
merged 1 commit into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ private ScalarOperator rewrite(ScalarOperator predicate, boolean onlyExtractColu

Set<ScalarOperator> conjuncts = Sets.newLinkedHashSet();
conjuncts.addAll(Utils.extractConjuncts(predicate));
predicate = Utils.compoundAnd(conjuncts);

Map<ScalarOperator, ValueDescriptor> extractMap = extractImpl(predicate);

Expand All @@ -78,7 +79,6 @@ private ScalarOperator rewrite(ScalarOperator predicate, boolean onlyExtractColu
return predicate;
}

predicate = Utils.compoundAnd(Lists.newArrayList(conjuncts));
if (isOnlyOrCompound(predicate)) {
Set<ColumnRefOperator> c = Sets.newHashSet(Utils.extractColumnRef(predicate));
if (c.size() == extractMap.size() &&
Expand All @@ -91,6 +91,17 @@ private ScalarOperator rewrite(ScalarOperator predicate, boolean onlyExtractColu
List<ScalarOperator> cs = Utils.extractConjuncts(predicate);
Set<ColumnRefOperator> cf = new HashSet<>(Utils.extractColumnRef(predicate));

// getSourceCount = cs.size() means that all AND components share the same column ref
// and can be merged into one range predicate. This goes wrong when the predicate is
// date_trunc(YEAR, dt) = '2024-01-01' AND mode = 'Buzz' AND
// date_trunc(YEAR, dt) = '2024-01-01' AND mode = 'Buzz' // duplicate
//
// Only mode = 'Buzz' is an extractable range predicate, so its corresponding ValueDescriptor's
// sourceCount is 2 (since it occurs twice), and cs.size() (also 2 in this example) is the
// number of unique column refs of the predicate. The two values are incidentally equal, so
// the check yields a wrong result.
//
// Components of AND/OR should therefore be deduplicated first to avoid this issue.
if (extractMap.values().stream().allMatch(valueDescriptor -> valueDescriptor.sourceCount == cs.size())
&& extractMap.size() == cf.size()) {
return extractExpr;
Expand Down Expand Up @@ -454,4 +465,4 @@ private static boolean isOnlyOrCompound(ScalarOperator predicate) {
return true;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,16 @@

package com.starrocks.sql.optimizer.rewrite;

import com.google.common.base.Preconditions;
import com.google.common.base.Supplier;
import com.google.common.collect.Lists;
import com.starrocks.catalog.FunctionSet;
import com.starrocks.catalog.Type;
import com.starrocks.sql.optimizer.Utils;
import com.starrocks.sql.optimizer.operator.OperatorType;
import com.starrocks.sql.optimizer.operator.scalar.BetweenPredicateOperator;
import com.starrocks.sql.optimizer.operator.scalar.BinaryPredicateOperator;
import com.starrocks.sql.optimizer.operator.scalar.CallOperator;
import com.starrocks.sql.optimizer.operator.scalar.CastOperator;
import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator;
import com.starrocks.sql.optimizer.operator.scalar.CompoundPredicateOperator;
Expand All @@ -18,8 +23,13 @@
import com.starrocks.sql.optimizer.rewrite.scalar.NormalizePredicateRule;
import com.starrocks.sql.optimizer.rewrite.scalar.ReduceCastRule;
import com.starrocks.sql.optimizer.rewrite.scalar.SimplifiedPredicateRule;
import org.junit.Assert;
import org.junit.Test;

import java.time.LocalDateTime;
import java.util.Arrays;
import java.util.stream.Collectors;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

Expand Down Expand Up @@ -119,4 +129,33 @@ public void testNormalizeIsNull() {
.rewrite(isnotNull, ScalarOperatorRewriter.DEFAULT_REWRITE_SCAN_PREDICATE_RULES);
assertEquals(ConstantOperator.TRUE, rewritten2);
}

@Test
public void testRangeExtract() {
    // Regression test for duplicated AND components defeating range extraction:
    // builds (date_trunc(YEAR, dt) = '2024-01-01' AND mode = 'Buzz') twice, ANDs the two
    // copies together, and verifies the extractor keeps BOTH predicates instead of
    // mistakenly discarding one of them.
    Supplier<ScalarOperator> predicate1Maker = () -> {
        ColumnRefOperator col1 = new ColumnRefOperator(1, Type.DATE, "dt", false);
        CallOperator call = new CallOperator(FunctionSet.DATE_TRUNC, Type.DATE,
                Arrays.asList(ConstantOperator.createVarchar("YEAR"), col1));
        return new BinaryPredicateOperator(BinaryPredicateOperator.BinaryType.EQ, call,
                ConstantOperator.createDate(LocalDateTime.of(2024, 1, 1, 0, 0, 0)));
    };

    Supplier<ScalarOperator> predicate2Maker = () -> {
        ColumnRefOperator col2 = new ColumnRefOperator(2, Type.VARCHAR, "mode", false);
        return new BinaryPredicateOperator(BinaryPredicateOperator.BinaryType.EQ, col2,
                ConstantOperator.createVarchar("Buzz"));
    };

    // Two structurally identical conjunctions: predicates == (p1 AND p2) AND (p1 AND p2).
    ScalarOperator predicate1 = Utils.compoundAnd(predicate1Maker.get(), predicate2Maker.get());
    ScalarOperator predicate2 = Utils.compoundAnd(predicate1Maker.get(), predicate2Maker.get());
    ScalarOperator predicates = Utils.compoundAnd(predicate1, predicate2);

    ScalarRangePredicateExtractor rangeExtractor = new ScalarRangePredicateExtractor();
    ScalarOperator result = rangeExtractor.rewriteOnlyColumn(Utils.compoundAnd(Utils.extractConjuncts(predicates)
            .stream().map(rangeExtractor::rewriteOnlyColumn).collect(Collectors.toList())));
    // Use a JUnit assertion rather than Preconditions so a null result reports as a
    // test failure instead of an IllegalStateException.
    Assert.assertNotNull(result);
    String expected = "date_trunc(YEAR, 1: dt) = 2024-01-01 AND 2: mode = Buzz";
    // assertEquals(expected, actual): the original call passed (actual, expect, actual),
    // abusing the (message, expected, actual) overload with the actual string as the
    // failure message, which produced confusing output on failure.
    Assert.assertEquals(expected, result.toString());
}
}
56 changes: 29 additions & 27 deletions fe/fe-core/src/test/java/com/starrocks/sql/plan/SubqueryTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1794,21 +1794,22 @@ public void testNestSubquery() throws Exception {
// Plan-shape test: rewrites "v1 = 1 OR v2 IN (correlated subquery)" and checks the
// generated fragment plan contains the expected aggregate/join fragments.
// NOTE(review): this hunk appears to be a diff with +/- markers stripped — L160 ends with
// a dangling '+' immediately followed by a new assertContains call, and two conflicting
// expected-plan blocks are present. Recover the original patch before editing; the text
// below is preserved verbatim.
public void testCorrelatedPredicateRewrite_1() throws Exception {
String sql = "select v1 from t0 where v1 = 1 or v2 in (select v4 from t1 where v2 = v4 and v5 = 1)";
String plan = getFragmentPlan(sql);
System.out.println(plan); // NOTE(review): debug print left in the test; consider removing.
assertContains(plan, "7:AGGREGATE (merge finalize)\n" +
" | group by: 8: v4\n" +
" | \n" +
" 6:EXCHANGE");
assertContains(plan, "13:AGGREGATE (update serialize)\n" +
" | STREAMING\n" +
" | output: count(1), count(9: v4)\n" +
" | group by: 10: v4\n" +
" | \n" +
" 12:Project\n" +
" | <slot 9> : 4: v4\n" +
" | <slot 10> : 4: v4\n" +
assertContains(plan, " |----14:AGGREGATE (merge finalize)\n" +
" | | output: count(11: countRows), count(12: countNotNulls)\n" +
" | | group by: 10: v4\n" +
" | | \n" +
" | 13:EXCHANGE");
assertContains(plan, " 9:HASH JOIN\n" +
" | join op: LEFT OUTER JOIN (BUCKET_SHUFFLE(S))\n" +
" | colocate: false, reason: \n" +
" | equal join conjunct: 2: v2 = 8: v4\n" +
" | \n" +
" 11:EXCHANGE");
" |----8:AGGREGATE (merge finalize)\n" +
" | | group by: 8: v4\n" +
" | | \n" +
" | 7:EXCHANGE\n" +
" | \n" +
" 3:EXCHANGE");
}

@Test
Expand All @@ -1817,20 +1818,21 @@ public void testCorrelatedPredicateRewrite_2() throws Exception {
"or v2 in (select v4 from t1 where v2 = v4 and v5 = 1)";

String plan = getFragmentPlan(sql);
assertContains(plan, "24:AGGREGATE (merge finalize)\n" +
" | group by: 12: v4\n" +
" | \n" +
" 23:EXCHANGE");
assertContains(plan, "30:AGGREGATE (update serialize)\n" +
" | STREAMING\n" +
" | output: count(1), count(13: v4)\n" +
" | group by: 14: v4\n" +
" | \n" +
" 29:Project\n" +
" | <slot 13> : 8: v4\n" +
" | <slot 14> : 8: v4\n" +
assertContains(plan, " |----30:AGGREGATE (merge finalize)\n" +
" | | output: count(15: countRows), count(16: countNotNulls)\n" +
" | | group by: 14: v4\n" +
" | | \n" +
" | 29:EXCHANGE");
assertContains(plan, " 17:HASH JOIN\n" +
" | join op: LEFT OUTER JOIN (BUCKET_SHUFFLE(S))\n" +
" | colocate: false, reason: \n" +
" | equal join conjunct: 1: v1 = 19: v5\n" +
" | \n" +
" 28:EXCHANGE");
" |----16:AGGREGATE (merge finalize)\n" +
" | | output: count(20: countRows), count(21: countNotNulls)\n" +
" | | group by: 19: v5\n" +
" | | \n" +
" | 15:EXCHANGE");
}

@Test
Expand Down
Loading