diff --git a/pinot-core/src/main/java/org/apache/pinot/core/data/table/TableResizer.java b/pinot-core/src/main/java/org/apache/pinot/core/data/table/TableResizer.java index 45ded8f1e538..452186b9098a 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/data/table/TableResizer.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/data/table/TableResizer.java @@ -54,9 +54,7 @@ public class TableResizer { private final int _numGroupByExpressions; private final Map _groupByExpressionIndexMap; private final AggregationFunction[] _aggregationFunctions; - private final Map _aggregationFunctionIndexMap; private final Map, Integer> _filteredAggregationIndexMap; - private final List> _filteredAggregationFunctions; private final int _numOrderByExpressions; private final OrderByValueExtractor[] _orderByValueExtractors; private final Comparator _intermediateRecordComparator; @@ -82,10 +80,8 @@ public TableResizer(DataSchema dataSchema, boolean hasFinalInput, QueryContext q _aggregationFunctions = queryContext.getAggregationFunctions(); assert _aggregationFunctions != null; - _aggregationFunctionIndexMap = queryContext.getAggregationFunctionIndexMap(); - assert _aggregationFunctionIndexMap != null; _filteredAggregationIndexMap = queryContext.getFilteredAggregationsIndexMap(); - _filteredAggregationFunctions = queryContext.getFilteredAggregationFunctions(); + assert _filteredAggregationIndexMap != null; List orderByExpressions = queryContext.getOrderByExpressions(); assert orderByExpressions != null; @@ -148,26 +144,26 @@ private OrderByValueExtractor getOrderByValueExtractor(ExpressionContext express FunctionContext function = expression.getFunction(); Preconditions.checkState(function != null, "Failed to find ORDER-BY expression: %s in the GROUP-BY clause", expression); + FunctionContext aggregation; + FilterContext filter; if (function.getType() == FunctionContext.Type.AGGREGATION) { // Aggregation function - int index = _aggregationFunctionIndexMap.get(function); - // For final aggregate result, we can handle it the same way as group key - return _hasFinalInput ? new GroupByExpressionExtractor(_numGroupByExpressions + index) - : new AggregationFunctionExtractor(index); + aggregation = function; + filter = null; } else if (function.getType() == FunctionContext.Type.TRANSFORM && "FILTER".equalsIgnoreCase( function.getFunctionName())) { // Filtered aggregation - FunctionContext aggregation = function.getArguments().get(0).getFunction(); - ExpressionContext filterExpression = function.getArguments().get(1); - FilterContext filter = RequestContextUtils.getFilter(filterExpression); - int index = _filteredAggregationIndexMap.get(Pair.of(aggregation, filter)); - // For final aggregate result, we can handle it the same way as group key - return _hasFinalInput ? new GroupByExpressionExtractor(_numGroupByExpressions + index) - : new AggregationFunctionExtractor(index, _filteredAggregationFunctions.get(index).getLeft()); + aggregation = function.getArguments().get(0).getFunction(); + filter = RequestContextUtils.getFilter(function.getArguments().get(1)); } else { // Post-aggregation function return new PostAggregationFunctionExtractor(function); } + + int index = _filteredAggregationIndexMap.get(Pair.of(aggregation, filter)); + // For final aggregate result, we can handle it the same way as group key + return _hasFinalInput ? new GroupByExpressionExtractor(_numGroupByExpressions + index) + : new AggregationFunctionExtractor(index); } /** @@ -441,11 +437,6 @@ private class AggregationFunctionExtractor implements OrderByValueExtractor { _aggregationFunction = _aggregationFunctions[aggregationFunctionIndex]; } - AggregationFunctionExtractor(int aggregationFunctionIndex, AggregationFunction aggregationFunction) { - _index = aggregationFunctionIndex + _numGroupByExpressions; - _aggregationFunction = aggregationFunction; - } - @Override public ColumnDataType getValueType() { return _aggregationFunction.getFinalResultColumnType(); diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java b/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java index 6c4a3d75c3df..aee9261d9480 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java @@ -91,10 +91,9 @@ public class QueryContext { // Pre-calculate the aggregation functions and columns for the query so that it can be shared across all the segments private AggregationFunction[] _aggregationFunctions; - private Map _aggregationFunctionIndexMap; - private boolean _hasFilteredAggregations; private List> _filteredAggregationFunctions; private Map, Integer> _filteredAggregationsIndexMap; + private boolean _hasFilteredAggregations; private Set _columns; // Other properties to be shared across all the segments @@ -272,22 +271,6 @@ public List> getFilteredAggregationFunc return _filteredAggregationFunctions; } - /** - * Returns the filtered aggregation expressions for the query. - */ - public boolean hasFilteredAggregations() { - return _hasFilteredAggregations; - } - - /** - * Returns a map from the AGGREGATION FunctionContext to the index of the corresponding AggregationFunction in the - * aggregation functions array. - */ - @Nullable - public Map getAggregationFunctionIndexMap() { - return _aggregationFunctionIndexMap; - } - /** * Returns a map from the filtered aggregation (pair of AGGREGATION FunctionContext and FILTER FilterContext) to the * index of corresponding AggregationFunction in the aggregation functions array. @@ -297,6 +280,13 @@ public Map, Integer> getFilteredAggregation return _filteredAggregationsIndexMap; } + /** + * Returns the filtered aggregation expressions for the query. + */ + public boolean hasFilteredAggregations() { + return _hasFilteredAggregations; + } + /** * Returns the columns (IDENTIFIER expressions) in the query. */ @@ -619,12 +609,7 @@ private void generateAggregationFunctions(QueryContext queryContext) { for (int i = 0; i < numAggregations; i++) { aggregationFunctions[i] = filteredAggregationFunctions.get(i).getLeft(); } - Map aggregationFunctionIndexMap = new HashMap<>(); - for (Map.Entry, Integer> entry : filteredAggregationsIndexMap.entrySet()) { - aggregationFunctionIndexMap.put(entry.getKey().getLeft(), entry.getValue()); - } queryContext._aggregationFunctions = aggregationFunctions; - queryContext._aggregationFunctionIndexMap = aggregationFunctionIndexMap; queryContext._filteredAggregationFunctions = filteredAggregationFunctions; queryContext._filteredAggregationsIndexMap = filteredAggregationsIndexMap; } diff --git a/pinot-core/src/test/java/org/apache/pinot/core/query/request/context/utils/BrokerRequestToQueryContextConverterTest.java b/pinot-core/src/test/java/org/apache/pinot/core/query/request/context/utils/BrokerRequestToQueryContextConverterTest.java index 7c74e022af8c..ef331ebf59c5 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/query/request/context/utils/BrokerRequestToQueryContextConverterTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/query/request/context/utils/BrokerRequestToQueryContextConverterTest.java @@ -480,21 +480,21 @@ public void testHardcodedQueries() { assertEquals(aggregationFunctions[3].getResultColumnName(), "sum(col4)"); assertEquals(aggregationFunctions[4].getResultColumnName(), "max(col4)"); assertEquals(aggregationFunctions[5].getResultColumnName(), "max(col1)"); - Map aggregationFunctionIndexMap = queryContext.getAggregationFunctionIndexMap(); - assertNotNull(aggregationFunctionIndexMap); - assertEquals(aggregationFunctionIndexMap.size(), 6); - assertEquals((int) aggregationFunctionIndexMap.get(new FunctionContext(FunctionContext.Type.AGGREGATION, "sum", - Collections.singletonList(ExpressionContext.forIdentifier("col1")))), 0); - assertEquals((int) aggregationFunctionIndexMap.get(new FunctionContext(FunctionContext.Type.AGGREGATION, "max", - Collections.singletonList(ExpressionContext.forIdentifier("col2")))), 1); - assertEquals((int) aggregationFunctionIndexMap.get(new FunctionContext(FunctionContext.Type.AGGREGATION, "min", - Collections.singletonList(ExpressionContext.forIdentifier("col2")))), 2); - assertEquals((int) aggregationFunctionIndexMap.get(new FunctionContext(FunctionContext.Type.AGGREGATION, "sum", - Collections.singletonList(ExpressionContext.forIdentifier("col4")))), 3); - assertEquals((int) aggregationFunctionIndexMap.get(new FunctionContext(FunctionContext.Type.AGGREGATION, "max", - Collections.singletonList(ExpressionContext.forIdentifier("col4")))), 4); - assertEquals((int) aggregationFunctionIndexMap.get(new FunctionContext(FunctionContext.Type.AGGREGATION, "max", - Collections.singletonList(ExpressionContext.forIdentifier("col1")))), 5); + Map, Integer> indexMap = queryContext.getFilteredAggregationsIndexMap(); + assertNotNull(indexMap); + assertEquals(indexMap.size(), 6); + assertEquals((int) indexMap.get(Pair.of(new FunctionContext(FunctionContext.Type.AGGREGATION, "sum", + Collections.singletonList(ExpressionContext.forIdentifier("col1"))), null)), 0); + assertEquals((int) indexMap.get(Pair.of(new FunctionContext(FunctionContext.Type.AGGREGATION, "max", + Collections.singletonList(ExpressionContext.forIdentifier("col2"))), null)), 1); + assertEquals((int) indexMap.get(Pair.of(new FunctionContext(FunctionContext.Type.AGGREGATION, "min", + Collections.singletonList(ExpressionContext.forIdentifier("col2"))), null)), 2); + assertEquals((int) indexMap.get(Pair.of(new FunctionContext(FunctionContext.Type.AGGREGATION, "sum", + Collections.singletonList(ExpressionContext.forIdentifier("col4"))), null)), 3); + assertEquals((int) indexMap.get(Pair.of(new FunctionContext(FunctionContext.Type.AGGREGATION, "max", + Collections.singletonList(ExpressionContext.forIdentifier("col4"))), null)), 4); + assertEquals((int) indexMap.get(Pair.of(new FunctionContext(FunctionContext.Type.AGGREGATION, "max", + Collections.singletonList(ExpressionContext.forIdentifier("col1"))), null)), 5); } // DistinctCountThetaSketch (string literal and escape quote) @@ -540,21 +540,10 @@ public void testFilteredAggregations() { assertTrue(filteredAggregationFunctions.get(1).getLeft() instanceof CountAggregationFunction); assertEquals(filteredAggregationFunctions.get(1).getRight().toString(), "foo < '6'"); - Map aggregationIndexMap = queryContext.getAggregationFunctionIndexMap(); - assertNotNull(aggregationIndexMap); - assertEquals(aggregationIndexMap.size(), 1); - for (Map.Entry entry : aggregationIndexMap.entrySet()) { - FunctionContext aggregation = entry.getKey(); - int index = entry.getValue(); - assertEquals(aggregation.toString(), "count(*)"); - assertTrue(index == 0 || index == 1); - } - - Map, Integer> filteredAggregationsIndexMap = - queryContext.getFilteredAggregationsIndexMap(); - assertNotNull(filteredAggregationsIndexMap); - assertEquals(filteredAggregationsIndexMap.size(), 2); - for (Map.Entry, Integer> entry : filteredAggregationsIndexMap.entrySet()) { + Map, Integer> indexMap = queryContext.getFilteredAggregationsIndexMap(); + assertNotNull(indexMap); + assertEquals(indexMap.size(), 2); + for (Map.Entry, Integer> entry : indexMap.entrySet()) { Pair pair = entry.getKey(); FunctionContext aggregation = pair.getLeft(); FilterContext filter = pair.getRight(); @@ -600,32 +589,10 @@ public void testFilteredAggregations() { assertTrue(filteredAggregationFunctions.get(3).getLeft() instanceof MinAggregationFunction); assertEquals(filteredAggregationFunctions.get(3).getRight().toString(), "salary > '50000'"); - Map aggregationIndexMap = queryContext.getAggregationFunctionIndexMap(); - assertNotNull(aggregationIndexMap); - assertEquals(aggregationIndexMap.size(), 2); - for (Map.Entry entry : aggregationIndexMap.entrySet()) { - FunctionContext aggregation = entry.getKey(); - int index = entry.getValue(); - switch (index) { - case 0: - case 1: - assertEquals(aggregation.toString(), "sum(salary)"); - break; - case 2: - case 3: - assertEquals(aggregation.toString(), "min(salary)"); - break; - default: - fail(); - break; - } - } - - Map, Integer> filteredAggregationsIndexMap = - queryContext.getFilteredAggregationsIndexMap(); - assertNotNull(filteredAggregationsIndexMap); - assertEquals(filteredAggregationsIndexMap.size(), 4); - for (Map.Entry, Integer> entry : filteredAggregationsIndexMap.entrySet()) { + Map, Integer> indexMap = queryContext.getFilteredAggregationsIndexMap(); + assertNotNull(indexMap); + assertEquals(indexMap.size(), 4); + for (Map.Entry, Integer> entry : indexMap.entrySet()) { Pair pair = entry.getKey(); FunctionContext aggregation = pair.getLeft(); FilterContext filter = pair.getRight(); diff --git a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java index ca1ea0b1f957..7f93ba75eff8 100644 --- a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java +++ b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java @@ -137,14 +137,6 @@ public class OfflineClusterIntegrationTest extends BaseClusterIntegrationTestSet new StarTreeIndexConfig(Collections.singletonList("DestState"), null, Collections.singletonList(AggregationFunctionColumnPair.COUNT_STAR.toColumnName()), null, 100); private static final String TEST_STAR_TREE_QUERY_2 = "SELECT COUNT(*) FROM mytable WHERE DestState = 'CA'"; - private static final String TEST_STAR_TREE_QUERY_FILTERED_AGG = - "SELECT COUNT(*), COUNT(*) FILTER (WHERE Carrier = 'UA') FROM mytable WHERE DestState = 'CA'"; - // This query contains a filtered aggregation which cannot be solved with startree, but the COUNT(*) still should be - private static final String TEST_STAR_TREE_QUERY_FILTERED_AGG_MIXED = - "SELECT COUNT(*), AVG(ArrDelay) FILTER (WHERE Carrier = 'UA') FROM mytable WHERE DestState = 'CA'"; - private static final StarTreeIndexConfig STAR_TREE_INDEX_CONFIG_3 = - new StarTreeIndexConfig(List.of("Carrier", "DestState"), null, - Collections.singletonList(AggregationFunctionColumnPair.COUNT_STAR.toColumnName()), null, 100); // For default columns test private static final String TEST_EXTRA_COLUMNS_QUERY = "SELECT COUNT(*) FROM mytable WHERE NewAddedIntMetric = 1"; @@ -3472,6 +3464,24 @@ public void testBooleanAggregation() testQuery("SELECT BOOL_OR(CAST(Diverted AS BOOLEAN)) FROM mytable"); } + @Test(dataProvider = "useBothQueryEngines") + public void testGroupByAggregationWithLimitZero(boolean useMultiStageQueryEngine) + throws Exception { + setUseMultiStageQueryEngine(useMultiStageQueryEngine); + testQuery("SELECT Origin, SUM(ArrDelay) FROM mytable GROUP BY Origin LIMIT 0"); + } + + @Test(dataProvider = "useBothQueryEngines") + public void testFilteredAggregationWithGroupByOrdering(boolean useMultiStageQueryEngine) + throws Exception { + setUseMultiStageQueryEngine(useMultiStageQueryEngine); + + // Test the ordering is correctly applied to the correct aggregation (the one without FILTER clause) + // See https://github.com/apache/pinot/pull/13784 + testQuery("SELECT DestCityName, COUNT(*) AS c1, COUNT(*) FILTER (WHERE AirTime = 0) AS c2 FROM mytable " + + "GROUP BY DestCityName ORDER BY c1 DESC LIMIT 10"); + } + private String buildSkipIndexesOption(String columnsAndIndexes) { return "SET " + SKIP_INDEXES + "='" + columnsAndIndexes + "'; "; }