diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateToSemiJoinRule.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateToSemiJoinRule.java deleted file mode 100644 index 327921df713d..000000000000 --- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateToSemiJoinRule.java +++ /dev/null @@ -1,132 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.pinot.calcite.rel.rules; - -import java.util.ArrayList; -import java.util.List; -import javax.annotation.Nullable; -import org.apache.calcite.plan.RelOptCluster; -import org.apache.calcite.plan.RelOptRule; -import org.apache.calcite.plan.RelOptRuleCall; -import org.apache.calcite.plan.RelOptUtil; -import org.apache.calcite.rel.RelNode; -import org.apache.calcite.rel.core.Aggregate; -import org.apache.calcite.rel.core.Join; -import org.apache.calcite.rel.core.JoinInfo; -import org.apache.calcite.rel.rules.CoreRules; -import org.apache.calcite.rex.RexBuilder; -import org.apache.calcite.rex.RexNode; -import org.apache.calcite.tools.RelBuilder; -import org.apache.calcite.tools.RelBuilderFactory; -import org.apache.calcite.util.ImmutableBitSet; -import org.apache.calcite.util.ImmutableIntList; - - -/** - * SemiJoinRule that matches an Aggregate on top of a Join with an Aggregate as its right child. - * - * @see CoreRules#PROJECT_TO_SEMI_JOIN - */ -public class PinotAggregateToSemiJoinRule extends RelOptRule { - public static final PinotAggregateToSemiJoinRule INSTANCE = - new PinotAggregateToSemiJoinRule(PinotRuleUtils.PINOT_REL_FACTORY); - - public PinotAggregateToSemiJoinRule(RelBuilderFactory factory) { - super(operand(Aggregate.class, - some(operand(Join.class, some(operand(RelNode.class, any()), operand(Aggregate.class, any()))))), factory, - null); - } - - @Override - public void onMatch(RelOptRuleCall call) { - final Aggregate topAgg = call.rel(0); - final Join join = (Join) PinotRuleUtils.unboxRel(topAgg.getInput()); - final RelNode left = PinotRuleUtils.unboxRel(join.getInput(0)); - final Aggregate rightAgg = (Aggregate) PinotRuleUtils.unboxRel(join.getInput(1)); - perform(call, topAgg, join, left, rightAgg); - } - - - protected void perform(RelOptRuleCall call, @Nullable Aggregate topAgg, - Join join, RelNode left, Aggregate rightAgg) { - final RelOptCluster cluster = join.getCluster(); - final RexBuilder rexBuilder = cluster.getRexBuilder(); - if (topAgg != null) { - final ImmutableBitSet aggBits = ImmutableBitSet.of(RelOptUtil.getAllFields(topAgg)); - final ImmutableBitSet rightBits = - ImmutableBitSet.range(left.getRowType().getFieldCount(), - join.getRowType().getFieldCount()); - if (aggBits.intersects(rightBits)) { - return; - } - } else { - if (join.getJoinType().projectsRight() - && !isEmptyAggregate(rightAgg)) { - return; - } - } - final JoinInfo joinInfo = join.analyzeCondition(); - if (!joinInfo.rightSet().equals( - ImmutableBitSet.range(rightAgg.getGroupCount()))) { - // Rule requires that aggregate key to be the same as the join key. - // By the way, neither a super-set nor a sub-set would work. - return; - } - if (!joinInfo.isEqui()) { - return; - } - final RelBuilder relBuilder = call.builder(); - relBuilder.push(left); - switch (join.getJoinType()) { - case SEMI: - case INNER: - final List newRightKeyBuilder = new ArrayList<>(); - final List aggregateKeys = rightAgg.getGroupSet().asList(); - for (int key : joinInfo.rightKeys) { - newRightKeyBuilder.add(aggregateKeys.get(key)); - } - final ImmutableIntList newRightKeys = ImmutableIntList.copyOf(newRightKeyBuilder); - relBuilder.push(rightAgg.getInput()); - final RexNode newCondition = - RelOptUtil.createEquiJoinCondition(relBuilder.peek(2, 0), - joinInfo.leftKeys, relBuilder.peek(2, 1), newRightKeys, - rexBuilder); - relBuilder.semiJoin(newCondition).hints(join.getHints()); - break; - - case LEFT: - // The right-hand side produces no more than 1 row (because of the - // Aggregate) and no fewer than 1 row (because of LEFT), and therefore - // we can eliminate the semi-join. - break; - - default: - throw new AssertionError(join.getJoinType()); - } - if (topAgg != null) { - relBuilder.aggregate(relBuilder.groupKey(topAgg.getGroupSet()), topAgg.getAggCallList()); - } - final RelNode relNode = relBuilder.build(); - call.transformTo(relNode); - } - - private static boolean isEmptyAggregate(Aggregate aggregate) { - return aggregate.getRowType().getFieldCount() == 0; - } -} diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotQueryRuleSets.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotQueryRuleSets.java index fdb75ee78f19..e831e7460a52 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotQueryRuleSets.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotQueryRuleSets.java @@ -73,7 +73,6 @@ private PinotQueryRuleSets() { // join and semi-join rules CoreRules.PROJECT_TO_SEMI_JOIN, - PinotAggregateToSemiJoinRule.INSTANCE, // convert non-all union into all-union + distinct CoreRules.UNION_TO_DISTINCT, diff --git a/pinot-query-planner/src/test/resources/queries/JoinPlans.json b/pinot-query-planner/src/test/resources/queries/JoinPlans.json index fb63399fac71..d48795dc30cd 100644 --- a/pinot-query-planner/src/test/resources/queries/JoinPlans.json +++ b/pinot-query-planner/src/test/resources/queries/JoinPlans.json @@ -111,7 +111,7 @@ }, { "description": "Inner join with group by", - "sql": "EXPLAIN PLAN FOR SELECT a.col1, AVG(b.col3) FROM a JOIN b ON a.col1 = b.col2 WHERE a.col3 >= 0 AND a.col2 = 'a' AND b.col3 < 0 GROUP BY a.col1", + "sql": "EXPLAIN PLAN FOR SELECT a.col1, AVG(b.col3) FROM a JOIN b ON a.col1 = b.col2 WHERE a.col3 >= 0 AND a.col2 = 'a' AND b.col3 < 0 GROUP BY a.col1", "output": [ "Execution Plan", "\nLogicalProject(col1=[$0], EXPR$1=[/(CAST($1):DOUBLE NOT NULL, $2)])", @@ -222,6 +222,21 @@ }, { "description": "Semi join with IN clause", + "sql": "EXPLAIN PLAN FOR SELECT col1, col2 FROM a WHERE col3 IN (SELECT col3 FROM b)", + "output": [ + "Execution Plan", + "\nLogicalProject(col1=[$0], col2=[$1])", + "\n LogicalJoin(condition=[=($2, $3)], joinType=[semi])", + "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])", + "\n LogicalTableScan(table=[[default, a]])", + "\n PinotLogicalExchange(distribution=[broadcast], relExchangeType=[PIPELINE_BREAKER])", + "\n LogicalProject(col3=[$2])", + "\n LogicalTableScan(table=[[default, b]])", + "\n" + ] + }, + { + "description": "Semi join with IN clause and join strategy override", "sql": "EXPLAIN PLAN FOR SELECT /*+ joinOptions(join_strategy = 'hash') */ col1, col2 FROM a WHERE col3 IN (SELECT col3 FROM b)", "output": [ "Execution Plan", @@ -237,7 +252,60 @@ ] }, { - "description": "Semi join with multiple IN clause", + "description": "Semi join with IN clause on distinct values", + "sql": "EXPLAIN PLAN FOR SELECT col1, col2 FROM a WHERE col3 IN (SELECT DISTINCT col3 FROM b)", + "output": [ + "Execution Plan", + "\nLogicalProject(col1=[$0], col2=[$1])", + "\n LogicalJoin(condition=[=($2, $3)], joinType=[semi])", + "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])", + "\n LogicalTableScan(table=[[default, a]])", + "\n PinotLogicalExchange(distribution=[broadcast], relExchangeType=[PIPELINE_BREAKER])", + "\n PinotLogicalAggregate(group=[{0}], aggType=[FINAL])", + "\n PinotLogicalExchange(distribution=[hash[0]])", + "\n PinotLogicalAggregate(group=[{2}], aggType=[LEAF])", + "\n LogicalTableScan(table=[[default, b]])", + "\n" + ] + }, + { + "description": "Semi join with IN clause then aggregate with group by", + "sql": "EXPLAIN PLAN FOR SELECT col1, SUM(col6) FROM a WHERE col3 IN (SELECT col3 FROM b) GROUP BY col1", + "output": [ + "Execution Plan", + "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], aggType=[FINAL])", + "\n PinotLogicalExchange(distribution=[hash[0]])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($2)], aggType=[LEAF])", + "\n LogicalJoin(condition=[=($1, $3)], joinType=[semi])", + "\n LogicalProject(col1=[$0], col3=[$2], col6=[$5])", + "\n LogicalTableScan(table=[[default, a]])", + "\n PinotLogicalExchange(distribution=[broadcast], relExchangeType=[PIPELINE_BREAKER])", + "\n LogicalProject(col3=[$2])", + "\n LogicalTableScan(table=[[default, b]])", + "\n" + ] + }, + { + "description": "Semi join with IN clause of distinct values then aggregate with group by", + "sql": "EXPLAIN PLAN FOR SELECT col1, SUM(col6) FROM a WHERE col3 IN (SELECT DISTINCT col3 FROM b) GROUP BY col1", + "output": [ + "Execution Plan", + "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], aggType=[FINAL])", + "\n PinotLogicalExchange(distribution=[hash[0]])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($2)], aggType=[LEAF])", + "\n LogicalJoin(condition=[=($1, $3)], joinType=[semi])", + "\n LogicalProject(col1=[$0], col3=[$2], col6=[$5])", + "\n LogicalTableScan(table=[[default, a]])", + "\n PinotLogicalExchange(distribution=[broadcast], relExchangeType=[PIPELINE_BREAKER])", + "\n PinotLogicalAggregate(group=[{0}], aggType=[FINAL])", + "\n PinotLogicalExchange(distribution=[hash[0]])", + "\n PinotLogicalAggregate(group=[{2}], aggType=[LEAF])", + "\n LogicalTableScan(table=[[default, b]])", + "\n" + ] + }, + { + "description": "Semi join with multiple IN clause and join strategy override", "sql": "EXPLAIN PLAN FOR SELECT /*+ joinOptions(join_strategy = 'hash') */ col1, col2 FROM a WHERE col2 = 'test' AND col3 IN (SELECT col3 FROM b WHERE col1='foo') AND col3 IN (SELECT col3 FROM b WHERE col1='bar') AND col3 IN (SELECT col3 FROM b WHERE col1='foobar')", "output": [ "Execution Plan", diff --git a/pinot-query-planner/src/test/resources/queries/PinotHintablePlans.json b/pinot-query-planner/src/test/resources/queries/PinotHintablePlans.json index f26a1330169b..998bf0560633 100644 --- a/pinot-query-planner/src/test/resources/queries/PinotHintablePlans.json +++ b/pinot-query-planner/src/test/resources/queries/PinotHintablePlans.json @@ -293,6 +293,58 @@ "\n" ] }, + { + "description": "agg + semi-join on colocated tables then group by on partition column with join and agg hint", + "sql": "EXPLAIN PLAN FOR SELECT /*+ joinOptions(is_colocated_by_join_keys='true'), aggOptions(is_partitioned_by_group_by_keys='true') */ a.col2, SUM(a.col3) FROM a /*+ tableOptions(partition_function='hashcode', partition_key='col2', partition_size='4') */ WHERE a.col2 IN (SELECT col1 FROM b /*+ tableOptions(partition_function='hashcode', partition_key='col1', partition_size='4') */ WHERE b.col3 > 0) GROUP BY 1", + "output": [ + "Execution Plan", + "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], aggType=[DIRECT])", + "\n LogicalJoin(condition=[=($0, $2)], joinType=[semi])", + "\n LogicalProject(col2=[$1], col3=[$2])", + "\n LogicalTableScan(table=[[default, a]])", + "\n PinotLogicalExchange(distribution=[hash[0]], relExchangeType=[PIPELINE_BREAKER])", + "\n LogicalProject(col1=[$0])", + "\n LogicalFilter(condition=[>($2, 0)])", + "\n LogicalTableScan(table=[[default, b]])", + "\n" + ] + }, + { + "description": "agg + semi-join with distinct values on colocated tables then group by on partition column", + "sql": "EXPLAIN PLAN FOR SELECT a.col2, SUM(a.col3) FROM a /*+ tableOptions(partition_function='hashcode', partition_key='col2', partition_size='4') */ WHERE a.col2 IN (SELECT DISTINCT col1 FROM b /*+ tableOptions(partition_function='hashcode', partition_key='col1', partition_size='4') */ WHERE b.col3 > 0) GROUP BY 1", + "output": [ + "Execution Plan", + "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], aggType=[FINAL])", + "\n PinotLogicalExchange(distribution=[hash[0]])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], aggType=[LEAF])", + "\n LogicalJoin(condition=[=($0, $2)], joinType=[semi])", + "\n LogicalProject(col2=[$1], col3=[$2])", + "\n LogicalTableScan(table=[[default, a]])", + "\n PinotLogicalExchange(distribution=[broadcast], relExchangeType=[PIPELINE_BREAKER])", + "\n PinotLogicalAggregate(group=[{0}], aggType=[FINAL])", + "\n PinotLogicalExchange(distribution=[hash[0]])", + "\n PinotLogicalAggregate(group=[{0}], aggType=[LEAF])", + "\n LogicalFilter(condition=[>($2, 0)])", + "\n LogicalTableScan(table=[[default, b]])", + "\n" + ] + }, + { + "description": "agg + semi-join with distinct values on colocated tables then group by on partition column with join and agg hint", + "sql": "EXPLAIN PLAN FOR SELECT /*+ joinOptions(is_colocated_by_join_keys='true'), aggOptions(is_partitioned_by_group_by_keys='true') */ a.col2, SUM(a.col3) FROM a /*+ tableOptions(partition_function='hashcode', partition_key='col2', partition_size='4') */ WHERE a.col2 IN (SELECT DISTINCT col1 FROM b /*+ tableOptions(partition_function='hashcode', partition_key='col1', partition_size='4') */ WHERE b.col3 > 0) GROUP BY 1", + "output": [ + "Execution Plan", + "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], aggType=[DIRECT])", + "\n LogicalJoin(condition=[=($0, $2)], joinType=[semi])", + "\n LogicalProject(col2=[$1], col3=[$2])", + "\n LogicalTableScan(table=[[default, a]])", + "\n PinotLogicalExchange(distribution=[hash[0]], relExchangeType=[PIPELINE_BREAKER])", + "\n PinotLogicalAggregate(group=[{0}], aggType=[DIRECT])", + "\n LogicalFilter(condition=[>($2, 0)])", + "\n LogicalTableScan(table=[[default, b]])", + "\n" + ] + }, { "description": "agg + semi-join on pre-partitioned main tables then group by on partition column", "sql": "EXPLAIN PLAN FOR SELECT a.col2, SUM(a.col3) FROM a /*+ tableOptions(partition_function='hashcode', partition_key='col2', partition_size='4') */ WHERE a.col2 IN (SELECT col1 FROM b WHERE b.col3 > 0) GROUP BY 1",