From f73e7bdd1d68bd33752aa76b19a2afa5ca16f5f5 Mon Sep 17 00:00:00 2001 From: satanson Date: Mon, 9 Sep 2024 16:05:36 +0800 Subject: [PATCH] [BugFix] RangePredicateExtractor discard predicates mistakenly Signed-off-by: satanson --- .../ScalarRangePredicateExtractor.java | 13 ++++++- .../rewrite/ScalarOperatorRewriterTest.java | 37 +++++++++++++++++++ .../sql/plan/PruneUKFKJoinRuleTest.java | 2 +- .../com/starrocks/sql/plan/SubqueryTest.java | 25 ++++++++----- 4 files changed, 66 insertions(+), 11 deletions(-) diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rewrite/ScalarRangePredicateExtractor.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rewrite/ScalarRangePredicateExtractor.java index 22050d77c48d3..85ede28c5b7b2 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rewrite/ScalarRangePredicateExtractor.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rewrite/ScalarRangePredicateExtractor.java @@ -57,6 +57,7 @@ private ScalarOperator rewrite(ScalarOperator predicate, boolean onlyExtractColu Set conjuncts = Sets.newLinkedHashSet(); conjuncts.addAll(Utils.extractConjuncts(predicate)); + predicate = Utils.compoundAnd(conjuncts); Map extractMap = extractImpl(predicate); @@ -82,7 +83,6 @@ private ScalarOperator rewrite(ScalarOperator predicate, boolean onlyExtractColu return predicate; } - predicate = Utils.compoundAnd(Lists.newArrayList(conjuncts)); if (isOnlyOrCompound(predicate)) { Set c = Sets.newHashSet(Utils.extractColumnRef(predicate)); if (c.size() == extractMap.size() && @@ -95,6 +95,17 @@ private ScalarOperator rewrite(ScalarOperator predicate, boolean onlyExtractColu List cs = Utils.extractConjuncts(predicate); Set cf = new HashSet<>(Utils.extractColumnRef(predicate)); + // getSourceCount = cs.size() means that all and components have the same column ref + // and it can be merged into one range predicate. mistakenly, when the predicate is + // date_trunc(YEAR, dt) = '2024-01-01' AND mode = 'Buzz' AND + // date_trunc(YEAR, dt) = '2024-01-01' AND mode = 'Buzz' //duplicate + // + // only mode = 'Buzz' is a extractable range predicate, so its corresponding ValueDescriptor's + // sourceCount = 2(since it occurs twice), and cs(it is also 2 in this example)are number of + // unique column refs of the predicate, the two values are properly equivalent, so it yields + // wrong result. + // + // Components of AND/OR should be deduplicated at first to avoid this issue. if (extractMap.values().stream().allMatch(valueDescriptor -> valueDescriptor.getSourceCount() == cs.size()) && extractMap.size() == cf.size()) { if (result.size() == conjuncts.size()) { diff --git a/fe/fe-core/src/test/java/com/starrocks/sql/optimizer/rewrite/ScalarOperatorRewriterTest.java b/fe/fe-core/src/test/java/com/starrocks/sql/optimizer/rewrite/ScalarOperatorRewriterTest.java index 756820e32bcb2..b25236c810c36 100644 --- a/fe/fe-core/src/test/java/com/starrocks/sql/optimizer/rewrite/ScalarOperatorRewriterTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/sql/optimizer/rewrite/ScalarOperatorRewriterTest.java @@ -14,12 +14,17 @@ package com.starrocks.sql.optimizer.rewrite; +import com.google.common.base.Preconditions; +import com.google.common.base.Supplier; import com.google.common.collect.Lists; import com.starrocks.analysis.BinaryType; +import com.starrocks.catalog.FunctionSet; import com.starrocks.catalog.Type; +import com.starrocks.sql.optimizer.Utils; import com.starrocks.sql.optimizer.operator.OperatorType; import com.starrocks.sql.optimizer.operator.scalar.BetweenPredicateOperator; import com.starrocks.sql.optimizer.operator.scalar.BinaryPredicateOperator; +import com.starrocks.sql.optimizer.operator.scalar.CallOperator; import com.starrocks.sql.optimizer.operator.scalar.CastOperator; import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator; import com.starrocks.sql.optimizer.operator.scalar.CompoundPredicateOperator; @@ -35,6 +40,10 @@ import org.junit.Assert; import org.junit.Test; +import java.time.LocalDateTime; +import java.util.Arrays; +import java.util.stream.Collectors; + import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -163,4 +172,32 @@ public void testNormalizeIsNull() { .rewrite(isnotNull, ScalarOperatorRewriter.DEFAULT_REWRITE_SCAN_PREDICATE_RULES); Assert.assertEquals(ConstantOperator.TRUE, rewritten2); } + + @Test + public void testRangeExtract() { + Supplier predicate1Maker = () -> { + ColumnRefOperator col1 = new ColumnRefOperator(1, Type.DATE, "dt", false); + CallOperator call = new CallOperator(FunctionSet.DATE_TRUNC, Type.DATE, + Arrays.asList(ConstantOperator.createVarchar("YEAR"), col1)); + return new BinaryPredicateOperator(BinaryType.EQ, call, ConstantOperator.createDate( + LocalDateTime.of(2024, 1, 1, 0, 0, 0))); + }; + + Supplier predicate2Maker = () -> { + ColumnRefOperator col2 = new ColumnRefOperator(2, Type.VARCHAR, "mode", false); + return new BinaryPredicateOperator(BinaryType.EQ, col2, ConstantOperator.createVarchar("Buzz")); + }; + + ScalarOperator predicate1 = Utils.compoundAnd(predicate1Maker.get(), predicate2Maker.get()); + ScalarOperator predicate2 = Utils.compoundAnd(predicate1Maker.get(), predicate2Maker.get()); + ScalarOperator predicates = Utils.compoundAnd(predicate1, predicate2); + + ScalarRangePredicateExtractor rangeExtractor = new ScalarRangePredicateExtractor(); + ScalarOperator result = rangeExtractor.rewriteOnlyColumn(Utils.compoundAnd(Utils.extractConjuncts(predicates) + .stream().map(rangeExtractor::rewriteOnlyColumn).collect(Collectors.toList()))); + Preconditions.checkState(result != null); + String expect = "date_trunc(YEAR, 1: dt) = 2024-01-01 AND 2: mode = Buzz"; + String actual = result.toString(); + Assert.assertEquals(actual, expect, actual); + } } diff --git a/fe/fe-core/src/test/java/com/starrocks/sql/plan/PruneUKFKJoinRuleTest.java b/fe/fe-core/src/test/java/com/starrocks/sql/plan/PruneUKFKJoinRuleTest.java index 24f85caf3e9f4..27ee16afb27bf 100644 --- a/fe/fe-core/src/test/java/com/starrocks/sql/plan/PruneUKFKJoinRuleTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/sql/plan/PruneUKFKJoinRuleTest.java @@ -137,7 +137,7 @@ public void canPrune() throws Exception { "t_uk join t_fk on t_uk.v1 = t_fk.v1 and t_fk.v1 = 5"; String plan = getFragmentPlan(sql); assertNotContains(plan, "HASH JOIN"); - assertNotContains(plan, "v1 IS NOT NULL"); + assertContains(plan, "v1 IS NOT NULL"); assertContains(plan, "4: v1 = 5"); } } diff --git a/fe/fe-core/src/test/java/com/starrocks/sql/plan/SubqueryTest.java b/fe/fe-core/src/test/java/com/starrocks/sql/plan/SubqueryTest.java index 3942f2922647a..e98a2fbcd6b3c 100644 --- a/fe/fe-core/src/test/java/com/starrocks/sql/plan/SubqueryTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/sql/plan/SubqueryTest.java @@ -1799,10 +1799,10 @@ public void testCorrelatedPredicateRewrite_1() throws Exception { String sql = "select v1 from t0 where v1 = 1 or v2 in (select v4 from t1 where v2 = v4 and v5 = 1)"; String plan = getFragmentPlan(sql); - assertContains(plan, "15:AGGREGATE (merge finalize)\n" + - " | group by: 8: v4\n" + - " | \n" + - " 14:EXCHANGE"); + assertContains(plan, " |----15:AGGREGATE (merge finalize)\n" + + " | | group by: 8: v4\n" + + " | | \n" + + " | 14:EXCHANGE"); assertContains(plan, "9:HASH JOIN\n" + " | join op: RIGHT OUTER JOIN (BUCKET_SHUFFLE(S))\n" + " | colocate: false, reason: \n" + @@ -1823,16 +1823,23 @@ public void testCorrelatedPredicateRewrite_2() throws Exception { "or v2 in (select v4 from t1 where v2 = v4 and v5 = 1)"; String plan = getFragmentPlan(sql); - assertContains(plan, "33:AGGREGATE (merge finalize)\n" + - " | group by: 12: v4\n" + + assertContains(plan, " |----32:AGGREGATE (merge finalize)\n" + + " | | group by: 12: v4\n" + + " | | \n" + + " | 31:EXCHANGE"); + assertContains(plan, " 27:Project\n" + + " | : 1: v1\n" + + " | : 2: v2\n" + + " | : 7: expr\n" + + " | : 15: countRows\n" + + " | : 16: countNotNulls\n" + " | \n" + - " 32:EXCHANGE"); - assertContains(plan, "27:HASH JOIN\n" + + " 26:HASH JOIN\n" + " | join op: RIGHT OUTER JOIN (BUCKET_SHUFFLE(S))\n" + " | colocate: false, reason: \n" + " | equal join conjunct: 14: v4 = 2: v2\n" + " | \n" + - " |----26:EXCHANGE\n" + + " |----25:EXCHANGE\n" + " | \n" + " 6:AGGREGATE (merge finalize)\n" + " | output: count(15: countRows), count(16: countNotNulls)\n" +