diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index f6c212a4b3c9..b41d12628a89 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -33,6 +33,7 @@ use datafusion_expr::expr::Alias; use datafusion_expr::expr_rewriter::{ normalize_col, normalize_col_with_schemas_and_ambiguity_check, normalize_cols, }; +use datafusion_expr::logical_plan::tree_node::unwrap_arc; use datafusion_expr::utils::{ expand_qualified_wildcard, expand_wildcard, expr_as_column_expr, expr_to_columns, find_aggregate_exprs, find_window_exprs, @@ -359,24 +360,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } fn try_process_aggregate_unnest(&self, input: LogicalPlan) -> Result { - match &input { + match input { LogicalPlan::Aggregate(agg) => { + let agg_expr = agg.aggr_expr.clone(); let (new_input, new_group_by_exprs) = self.try_process_group_by_unnest(agg)?; LogicalPlanBuilder::from(new_input) - .aggregate(new_group_by_exprs, agg.aggr_expr.clone())? + .aggregate(new_group_by_exprs, agg_expr)? .build() } - LogicalPlan::Filter(filter) => match filter.input.as_ref() { - LogicalPlan::Aggregate(agg) => { - let (new_input, new_group_by_exprs) = - self.try_process_group_by_unnest(agg)?; - LogicalPlanBuilder::from(new_input) - .aggregate(new_group_by_exprs, agg.aggr_expr.clone())? - .filter(filter.predicate.clone())? - .build() - } - _ => Ok(input), + LogicalPlan::Filter(mut filter) => { + filter.input = Arc::new(self.try_process_aggregate_unnest(unwrap_arc(filter.input))?); + Ok(LogicalPlan::Filter(filter)) }, _ => Ok(input), } @@ -386,13 +381,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { /// Return the new input and group_by_exprs of Aggregate. fn try_process_group_by_unnest( &self, - agg: &Aggregate, + agg: Aggregate, ) -> Result<(LogicalPlan, Vec)> { let mut aggr_expr_using_columns: Option> = None; - let input = agg.input.as_ref(); - let group_by_exprs = &agg.group_expr; - let aggr_exprs = &agg.aggr_expr; + let Aggregate { + input, + group_expr: group_by_exprs, + aggr_expr: aggr_exprs, + .. + } = agg; // process unnest of group_by_exprs, and input of agg will be rewritten // for example: @@ -410,8 +408,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Projection: tab.array_col AS unnest(tab.array_col) // TableScan: tab // ``` - let mut intermediate_plan = input.clone(); - let mut intermediate_select_exprs = group_by_exprs.to_vec(); + let mut intermediate_plan = unwrap_arc(input); + let mut intermediate_select_exprs = group_by_exprs; loop { let mut unnest_columns = vec![]; @@ -442,7 +440,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Some(exprs) => (*exprs).clone(), None => { let mut columns = HashSet::new(); - for expr in aggr_exprs { + for expr in &aggr_exprs { expr.apply(|expr| { if let Expr::Column(c) = expr { columns.insert(Expr::Column(c.clone()));