diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index f32fed5db5a31..02d1aca267836 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1087,6 +1087,25 @@ impl Expr { } /// Remove an alias from an expression if one exists. + /// + /// If the expression is not an alias, the expression is returned unchanged. + /// This method does not remove aliases from nested expressions. + /// + /// # Example + /// ``` + /// # use datafusion_expr::col; + /// // `foo as "bar"` is unaliased to `foo` + /// let expr = col("foo").alias("bar"); + /// assert_eq!(expr.unalias(), col("foo")); + /// + /// // `foo as "bar" + baz` is not unaliased + /// let expr = col("foo").alias("bar") + col("baz"); + /// assert_eq!(expr.clone().unalias(), expr); + /// + /// // `foo as "bar" as "baz" is unalaised to foo as "bar" + /// let expr = col("foo").alias("bar").alias("baz"); + /// assert_eq!(expr.unalias(), col("foo").alias("bar")); + /// ``` pub fn unalias(self) -> Expr { match self { Expr::Alias(alias) => *alias.expr, @@ -1094,6 +1113,34 @@ impl Expr { } } + /// Recursively potentially multiple aliases from an expression. + /// + /// If the expression is not an alias, the expression is returned unchanged. + /// This method removes directly nested aliases, but not other nested + /// aliases. + /// + /// # Example + /// ``` + /// # use datafusion_expr::col; + /// // `foo as "bar"` is unaliased to `foo` + /// let expr = col("foo").alias("bar"); + /// assert_eq!(expr.unalias_nested(), col("foo")); + /// + /// // `foo as "bar" + baz` is not unaliased + /// let expr = col("foo").alias("bar") + col("baz"); + /// assert_eq!(expr.clone().unalias_nested(), expr); + /// + /// // `foo as "bar" as "baz" is unalaised to foo + /// let expr = col("foo").alias("bar").alias("baz"); + /// assert_eq!(expr.unalias_nested(), col("foo")); + /// ``` + pub fn unalias_nested(self) -> Expr { + match self { + Expr::Alias(alias) => alias.expr.unalias_nested(), + _ => self, + } + } + /// Return `self IN ` if `negated` is false, otherwise /// return `self NOT IN `.a pub fn in_list(self, list: Vec, negated: bool) -> Expr { diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 0f2aaa6cbcb38..5a9705381d7f2 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -26,18 +26,22 @@ use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::{ - get_required_group_by_exprs_indices, internal_err, Column, JoinType, Result, + get_required_group_by_exprs_indices, internal_datafusion_err, internal_err, Column, + JoinType, Result, }; -use datafusion_expr::expr::{Alias, ScalarFunction}; +use datafusion_expr::expr::Alias; use datafusion_expr::{ - logical_plan::LogicalPlan, projection_schema, Aggregate, BinaryExpr, Cast, Distinct, - Expr, Projection, TableScan, Window, + logical_plan::LogicalPlan, projection_schema, Aggregate, Distinct, Expr, Projection, + TableScan, Window, }; use crate::optimize_projections::required_indices::RequiredIndicies; -use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; +use crate::utils::NamePreserver; +use datafusion_common::tree_node::{ + Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, +}; +use datafusion_expr::logical_plan::tree_node::unwrap_arc; use hashbrown::HashMap; -use itertools::izip; /// Optimizer rule to prune unnecessary columns from intermediate schemas /// inside the [`LogicalPlan`]. This rule: @@ -67,12 +71,10 @@ impl OptimizeProjections { impl OptimizerRule for OptimizeProjections { fn try_optimize( &self, - plan: &LogicalPlan, - config: &dyn OptimizerConfig, + _plan: &LogicalPlan, + _config: &dyn OptimizerConfig, ) -> Result> { - // All output fields are necessary: - let indices = RequiredIndicies::new_for_all_exprs(plan); - optimize_projections(plan, config, indices) + internal_err!("Should have called OptimizeProjections::rewrite") } fn name(&self) -> &str { @@ -82,6 +84,20 @@ impl OptimizerRule for OptimizeProjections { fn apply_order(&self) -> Option { None } + + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( + &self, + plan: LogicalPlan, + config: &dyn OptimizerConfig, + ) -> Result> { + // All output fields are necessary: + let indices = RequiredIndicies::new_for_all_exprs(&plan); + optimize_projections(plan, config, indices) + } } /// Removes unnecessary columns (e.g. columns that do not appear in the output @@ -93,7 +109,7 @@ impl OptimizerRule for OptimizeProjections { /// - `plan`: A reference to the input `LogicalPlan` to optimize. /// - `config`: A reference to the optimizer configuration. /// - `indices`: A slice of column indices that represent the necessary column -/// indices for downstream operations. +/// indices for downstream (parent) plan nodes. /// /// # Returns /// @@ -102,101 +118,19 @@ impl OptimizerRule for OptimizeProjections { /// - `Ok(Some(LogicalPlan))`: An optimized `LogicalPlan` without unnecessary /// columns. /// - `Ok(None)`: Signal that the given logical plan did not require any change. -/// - `Err(error)`: An error occured during the optimization process. +/// - `Err(error)`: An error occurred during the optimization process. fn optimize_projections( - plan: &LogicalPlan, + plan: LogicalPlan, config: &dyn OptimizerConfig, indices: RequiredIndicies, -) -> Result> { - let child_required_indices: Vec = match plan { - LogicalPlan::Sort(_) - | LogicalPlan::Filter(_) - | LogicalPlan::Repartition(_) - | LogicalPlan::Unnest(_) - | LogicalPlan::Union(_) - | LogicalPlan::SubqueryAlias(_) - | LogicalPlan::Distinct(Distinct::On(_)) => { - // Pass index requirements from the parent as well as column indices - // that appear in this plan's expressions to its child. All these - // operators benefit from "small" inputs, so the projection_beneficial - // flag is `true`. - plan.inputs() - .into_iter() - .map(|input| { - indices - .clone() - .with_projection_beneficial() - .with_plan_exprs(plan, input.schema()) - }) - .collect::>()? - } - LogicalPlan::Limit(_) | LogicalPlan::Prepare(_) => { - // Pass index requirements from the parent as well as column indices - // that appear in this plan's expressions to its child. These operators - // do not benefit from "small" inputs, so the projection_beneficial - // flag is `false`. - plan.inputs() - .into_iter() - .map(|input| indices.clone().with_plan_exprs(plan, input.schema())) - .collect::>()? - } - LogicalPlan::Copy(_) - | LogicalPlan::Ddl(_) - | LogicalPlan::Dml(_) - | LogicalPlan::Explain(_) - | LogicalPlan::Analyze(_) - | LogicalPlan::Subquery(_) - | LogicalPlan::Distinct(Distinct::All(_)) => { - // These plans require all their fields, and their children should - // be treated as final plans -- otherwise, we may have schema a - // mismatch. - // TODO: For some subquery variants (e.g. a subquery arising from an - // EXISTS expression), we may not need to require all indices. - plan.inputs() - .into_iter() - .map(RequiredIndicies::new_for_all_exprs) - .collect() - } - LogicalPlan::Extension(extension) => { - let Some(necessary_children_indices) = - extension.node.necessary_children_exprs(indices.indices()) - else { - // Requirements from parent cannot be routed down to user defined logical plan safely - return Ok(None); - }; - let children = extension.node.inputs(); - if children.len() != necessary_children_indices.len() { - return internal_err!("Inconsistent length between children and necessary children indices. \ - Make sure `.necessary_children_exprs` implementation of the `UserDefinedLogicalNode` is \ - consistent with actual children length for the node."); - } - children - .into_iter() - .zip(necessary_children_indices) - .map(|(child, necessary_indices)| { - RequiredIndicies::new_from_indices(necessary_indices) - .with_plan_exprs(plan, child.schema()) - }) - .collect::>>()? - } - LogicalPlan::EmptyRelation(_) - | LogicalPlan::RecursiveQuery(_) - | LogicalPlan::Statement(_) - | LogicalPlan::Values(_) - | LogicalPlan::DescribeTable(_) => { - // These operators have no inputs, so stop the optimization process. - return Ok(None); - } +) -> Result> { + // Recursively rewrite any nodes that may be able to avoid computation given + // their parents' required indices. + match plan { LogicalPlan::Projection(proj) => { - return if let Some(proj) = merge_consecutive_projections(proj)? { - Ok(Some( - rewrite_projection_given_requirements(&proj, config, indices)? - // Even if we cannot optimize the projection, merge if possible: - .unwrap_or_else(|| LogicalPlan::Projection(proj)), - )) - } else { + return merge_consecutive_projections(proj)?.transform_data(|proj| { rewrite_projection_given_requirements(proj, config, indices) - }; + }) } LogicalPlan::Aggregate(aggregate) => { // Split parent requirements to GROUP BY and aggregate sections: @@ -211,6 +145,7 @@ fn optimize_projections( .iter() .map(|group_by_expr| group_by_expr.display_name()) .collect::>>()?; + let new_group_bys = if let Some(simplest_groupby_indices) = get_required_group_by_exprs_indices( aggregate.input.schema(), @@ -223,7 +158,7 @@ fn optimize_projections( .append(&simplest_groupby_indices) .get_at_indices(&aggregate.group_expr) } else { - aggregate.group_expr.clone() + aggregate.group_expr }; // Only use the absolutely necessary aggregate expressions required @@ -242,7 +177,9 @@ fn optimize_projections( && new_group_bys.is_empty() && !aggregate.aggr_expr.is_empty() { - new_aggr_expr = vec![aggregate.aggr_expr[0].clone()]; + // take the old, first aggregate expression + new_aggr_expr = aggregate.aggr_expr; + new_aggr_expr.resize_with(1, || unreachable!()); } let all_exprs_iter = new_group_bys.iter().chain(new_aggr_expr.iter()); @@ -251,32 +188,31 @@ fn optimize_projections( RequiredIndicies::new().with_exprs(schema, all_exprs_iter)?; let necessary_exprs = necessary_indices.get_required_exprs(schema); - let aggregate_input = if let Some(input) = - optimize_projections(&aggregate.input, config, necessary_indices)? - { - input - } else { - aggregate.input.as_ref().clone() - }; - - // Simplify the input of the aggregation by adding a projection so - // that its input only contains absolutely necessary columns for - // the aggregate expressions. Note that necessary_indices refer to - // fields in `aggregate.input.schema()`. - let (aggregate_input, _) = - add_projection_on_top_if_helpful(aggregate_input, necessary_exprs)?; - - // Create a new aggregate plan with the updated input and only the - // absolutely necessary fields: - return Aggregate::try_new( - Arc::new(aggregate_input), - new_group_bys, - new_aggr_expr, - ) - .map(|aggregate| Some(LogicalPlan::Aggregate(aggregate))); + return optimize_projections( + unwrap_arc(aggregate.input), + config, + necessary_indices, + )? + .transform_data(|aggregate_input| { + // Simplify the input of the aggregation by adding a projection so + // that its input only contains absolutely necessary columns for + // the aggregate expressions. Note that necessary_indices refer to + // fields in `aggregate.input.schema()`. + add_projection_on_top_if_helpful(aggregate_input, necessary_exprs) + })? + .map_data(|aggregate_input| { + // Create a new aggregate plan with the updated input and only the + // absolutely necessary fields: + Aggregate::try_new( + Arc::new(aggregate_input), + new_group_bys, + new_aggr_expr, + ) + .map(LogicalPlan::Aggregate) + }); } LogicalPlan::Window(window) => { - let input_schema = window.input.schema(); + let input_schema = window.input.schema().clone(); // Split parent requirements to child and window expression sections: let n_input_fields = input_schema.fields().len(); // Offset window expression indices so that they point to valid @@ -290,38 +226,152 @@ fn optimize_projections( // Get all the required column indices at the input, either by the // parent or window expression requirements. let required_indices = - child_reqs.with_exprs(input_schema, &new_window_expr)?; + child_reqs.with_exprs(&input_schema, &new_window_expr)?; - let window_child = if let Some(new_window_child) = - optimize_projections(&window.input, config, required_indices.clone())? - { - new_window_child - } else { - window.input.as_ref().clone() + return optimize_projections( + unwrap_arc(window.input), + config, + required_indices.clone(), + )? + .transform_data(|window_child| { + if new_window_expr.is_empty() { + // When no window expression is necessary, use the input directly: + Ok(Transformed::no(window_child)) + } else { + // Calculate required expressions at the input of the window. + // Please note that we use `input_schema`, because `required_indices` + // refers to that schema + let required_exprs = + required_indices.get_required_exprs(&input_schema); + let window_child = + add_projection_on_top_if_helpful(window_child, required_exprs)? + .data; + Window::try_new(new_window_expr, Arc::new(window_child)) + .map(LogicalPlan::Window) + .map(Transformed::yes) + } + }); + } + LogicalPlan::TableScan(table_scan) => { + let TableScan { + table_name, + source, + projection, + filters, + fetch, + projected_schema: _, + } = table_scan; + + // Get indices referred to in the original (schema with all fields) + // given projected indices. + let projection = match &projection { + Some(projection) => indices.into_mapped_indices(|idx| projection[idx]), + None => indices.into_inner(), }; + return TableScan::try_new( + table_name, + source, + Some(projection), + filters, + fetch, + ) + .map(LogicalPlan::TableScan) + .map(Transformed::yes); + } - return if new_window_expr.is_empty() { - // When no window expression is necessary, use the input directly: - Ok(Some(window_child)) - } else { - // Calculate required expressions at the input of the window. - // Please note that we use `old_child`, because `required_indices` - // refers to `old_child`. - let required_exprs = required_indices.get_required_exprs(input_schema); - let (window_child, _) = - add_projection_on_top_if_helpful(window_child, required_exprs)?; - Window::try_new(new_window_expr, Arc::new(window_child)) - .map(|window| Some(LogicalPlan::Window(window))) + // Other node types are handled below + _ => {} + }; + + // For other plan node types, calculate indices for columns they use and + // try to rewrite their children + let mut child_required_indices: Vec = match &plan { + LogicalPlan::Sort(_) + | LogicalPlan::Filter(_) + | LogicalPlan::Repartition(_) + | LogicalPlan::Unnest(_) + | LogicalPlan::Union(_) + | LogicalPlan::SubqueryAlias(_) + | LogicalPlan::Distinct(Distinct::On(_)) => { + // Pass index requirements from the parent as well as column indices + // that appear in this plan's expressions to its child. All these + // operators benefit from "small" inputs, so the projection_beneficial + // flag is `true`. + plan.inputs() + .into_iter() + .map(|input| { + indices + .clone() + .with_projection_beneficial() + .with_plan_exprs(&plan, input.schema()) + }) + .collect::>()? + } + LogicalPlan::Limit(_) | LogicalPlan::Prepare(_) => { + // Pass index requirements from the parent as well as column indices + // that appear in this plan's expressions to its child. These operators + // do not benefit from "small" inputs, so the projection_beneficial + // flag is `false`. + plan.inputs() + .into_iter() + .map(|input| indices.clone().with_plan_exprs(&plan, input.schema())) + .collect::>()? + } + LogicalPlan::Copy(_) + | LogicalPlan::Ddl(_) + | LogicalPlan::Dml(_) + | LogicalPlan::Explain(_) + | LogicalPlan::Analyze(_) + | LogicalPlan::Subquery(_) + | LogicalPlan::Distinct(Distinct::All(_)) => { + // These plans require all their fields, and their children should + // be treated as final plans -- otherwise, we may have schema a + // mismatch. + // TODO: For some subquery variants (e.g. a subquery arising from an + // EXISTS expression), we may not need to require all indices. + plan.inputs() + .into_iter() + .map(RequiredIndicies::new_for_all_exprs) + .collect() + } + LogicalPlan::Extension(extension) => { + let Some(necessary_children_indices) = + extension.node.necessary_children_exprs(indices.indices()) + else { + // Requirements from parent cannot be routed down to user defined logical plan safely + return Ok(Transformed::no(plan)); }; + let children = extension.node.inputs(); + if children.len() != necessary_children_indices.len() { + return internal_err!("Inconsistent length between children and necessary children indices. \ + Make sure `.necessary_children_exprs` implementation of the `UserDefinedLogicalNode` is \ + consistent with actual children length for the node."); + } + children + .into_iter() + .zip(necessary_children_indices) + .map(|(child, necessary_indices)| { + RequiredIndicies::new_from_indices(necessary_indices) + .with_plan_exprs(&plan, child.schema()) + }) + .collect::>>()? + } + LogicalPlan::EmptyRelation(_) + | LogicalPlan::RecursiveQuery(_) + | LogicalPlan::Statement(_) + | LogicalPlan::Values(_) + | LogicalPlan::DescribeTable(_) => { + // These operators have no inputs, so stop the optimization process. + return Ok(Transformed::no(plan)); } LogicalPlan::Join(join) => { let left_len = join.left.schema().fields().len(); let (left_req_indices, right_req_indices) = split_join_requirements(left_len, indices, &join.join_type); let left_indices = - left_req_indices.with_plan_exprs(plan, join.left.schema())?; + left_req_indices.with_plan_exprs(&plan, join.left.schema())?; let right_indices = - right_req_indices.with_plan_exprs(plan, join.right.schema())?; + right_req_indices.with_plan_exprs(&plan, join.right.schema())?; // Joins benefit from "small" input tables (lower memory usage). // Therefore, each child benefits from projection: vec![ @@ -340,55 +390,53 @@ fn optimize_projections( right_indices.with_projection_beneficial(), ] } - LogicalPlan::TableScan(table_scan) => { - // Get indices referred to in the original (schema with all fields) - // given projected indices. - let projection = match &table_scan.projection { - Some(projection) => indices.into_mapped_indices(|idx| projection[idx]), - None => indices.into_inner(), - }; - return TableScan::try_new( - table_scan.table_name.clone(), - table_scan.source.clone(), - Some(projection), - table_scan.filters.clone(), - table_scan.fetch, - ) - .map(|table| Some(LogicalPlan::TableScan(table))); + // these nodes are explicitly rewritten in the match statement above + LogicalPlan::Projection(_) + | LogicalPlan::Aggregate(_) + | LogicalPlan::Window(_) + | LogicalPlan::TableScan(_) => { + return internal_err!( + "OptimizeProjection: should have handled in the match statement above" + ); } }; - let new_inputs = izip!(child_required_indices, plan.inputs().into_iter()) - .map(|(required_indices, child)| { - let projection_beneficial = required_indices.projection_beneficial(); - let project_exprs = required_indices.get_required_exprs(child.schema()); - let (input, is_changed) = if let Some(new_input) = - optimize_projections(child, config, required_indices)? - { - (new_input, true) - } else { - (child.clone(), false) - }; - let (input, proj_added) = if projection_beneficial { - add_projection_on_top_if_helpful(input, project_exprs)? - } else { - (input, false) - }; - Ok((is_changed || proj_added).then_some(input)) - }) - .collect::>>()?; - if new_inputs.iter().all(|child| child.is_none()) { - // All children are the same in this case, no need to change the plan: - Ok(None) + // Required indices are currently ordered (child0, child1, ...) + // but the loop pops off the last element, so we need to reverse the order + child_required_indices.reverse(); + if child_required_indices.len() != plan.inputs().len() { + return internal_err!( + "OptimizeProjection: child_required_indices length mismatch with plan inputs" + ); + } + + // Rewrite children of the plan + let transformed_plan = plan.map_children(|child| { + let required_indices = child_required_indices.pop().ok_or_else(|| { + internal_datafusion_err!( + "Unexpected number of required_indices in OptimizeProjections rule" + ) + })?; + + let projection_beneficial = required_indices.projection_beneficial(); + let project_exprs = required_indices.get_required_exprs(child.schema()); + + optimize_projections(child, config, required_indices)?.transform_data( + |new_input| { + if projection_beneficial { + add_projection_on_top_if_helpful(new_input, project_exprs) + } else { + Ok(Transformed::no(new_input)) + } + }, + ) + })?; + + // If any of the children are transformed, we need to potentially update the plan's schema + if transformed_plan.transformed { + transformed_plan.map_data(|plan| plan.recompute_schema()) } else { - // At least one of the children is changed: - let new_inputs = izip!(new_inputs, plan.inputs()) - // If new_input is `None`, this means child is not changed, so use - // `old_child` during construction: - .map(|(new_input, old_child)| new_input.unwrap_or_else(|| old_child.clone())) - .collect(); - let exprs = plan.expressions(); - plan.with_new_exprs(exprs, new_inputs).map(Some) + Ok(transformed_plan) } } @@ -412,22 +460,28 @@ fn optimize_projections( /// merged projection. /// - `Ok(None)`: Signals that merge is not beneficial (and has not taken place). /// - `Err(error)`: An error occured during the function call. -fn merge_consecutive_projections(proj: &Projection) -> Result> { - let LogicalPlan::Projection(prev_projection) = proj.input.as_ref() else { - return Ok(None); +fn merge_consecutive_projections(proj: Projection) -> Result> { + let Projection { + expr, + input, + schema, + .. + } = proj; + let LogicalPlan::Projection(prev_projection) = input.as_ref() else { + return Projection::try_new_with_schema(expr, input, schema).map(Transformed::no); }; // Count usages (referrals) of each projection expression in its input fields: let mut column_referral_map = HashMap::::new(); - for columns in proj.expr.iter().flat_map(|expr| expr.to_columns()) { + for columns in expr.iter().flat_map(|expr| expr.to_columns()) { for col in columns.into_iter() { *column_referral_map.entry(col.clone()).or_default() += 1; } } - // If an expression is non-trivial and appears more than once, consecutive - // projections will benefit from a compute-once approach. For details, see: - // https://github.com/apache/datafusion/issues/8296 + // If an expression is non-trivial and appears more than once, do not merge + // them as consecutive projections will benefit from a compute-once approach. + // For details, see: https://github.com/apache/datafusion/issues/8296 if column_referral_map.into_iter().any(|(col, usage)| { usage > 1 && !is_expr_trivial( @@ -435,33 +489,78 @@ fn merge_consecutive_projections(proj: &Projection) -> Result [prev_projection.schema.index_of_column(&col).unwrap()], ) }) { - return Ok(None); + // no change + return Projection::try_new_with_schema(expr, input, schema).map(Transformed::no); } - // If all the expression of the top projection can be rewritten, do so and - // create a new projection: - let new_exprs = proj - .expr - .iter() - .map(|expr| rewrite_expr(expr, prev_projection)) - .collect::>>>()?; - if let Some(new_exprs) = new_exprs { + let LogicalPlan::Projection(prev_projection) = unwrap_arc(input) else { + // We know it is a `LogicalPlan::Projection` from check above + unreachable!(); + }; + + // Try to rewrite the expressions in the current projection using the + // previous projection as input: + let name_preserver = NamePreserver::new_for_projection(); + let mut original_names = vec![]; + let new_exprs = expr.into_iter().map_until_stop_and_collect(|expr| { + original_names.push(name_preserver.save(&expr)?); + + // do not rewrite top level Aliases (rewriter will remove all aliases within exprs) + match expr { + Expr::Alias(Alias { + expr, + relation, + name, + }) => rewrite_expr(*expr, &prev_projection).map(|result| { + result.update_data(|expr| Expr::Alias(Alias::new(expr, relation, name))) + }), + e => rewrite_expr(e, &prev_projection), + } + })?; + + // if the expressions could be rewritten, create a new projection with the + // new expressions + if new_exprs.transformed { + // Add any needed aliases back to the expressions let new_exprs = new_exprs + .data .into_iter() - .zip(proj.expr.iter()) - .map(|(new_expr, old_expr)| { - new_expr.alias_if_changed(old_expr.name_for_alias()?) - }) + .zip(original_names.into_iter()) + .map(|(expr, original_name)| original_name.restore(expr)) .collect::>>()?; - Projection::try_new(new_exprs, prev_projection.input.clone()).map(Some) + Projection::try_new(new_exprs, prev_projection.input).map(Transformed::yes) } else { - Ok(None) + // not rewritten, so put the projection back together + let input = Arc::new(LogicalPlan::Projection(prev_projection)); + Projection::try_new_with_schema(new_exprs.data, input, schema) + .map(Transformed::no) } } -/// Trim the given expression by removing any unnecessary layers of aliasing. -/// If the expression is an alias, the function returns the underlying expression. -/// Otherwise, it returns the given expression as is. +// Check whether `expr` is trivial; i.e. it doesn't imply any computation. +fn is_expr_trivial(expr: &Expr) -> bool { + matches!(expr, Expr::Column(_) | Expr::Literal(_)) +} + +/// Rewrites a projection expression using the projection before it (i.e. its input) +/// This is a subroutine to the `merge_consecutive_projections` function. +/// +/// # Parameters +/// +/// * `expr` - A reference to the expression to rewrite. +/// * `input` - A reference to the input of the projection expression (itself +/// a projection). +/// +/// # Returns +/// +/// A `Result` object with the following semantics: +/// +/// - `Ok(Some(Expr))`: Rewrite was successful. Contains the rewritten result. +/// - `Ok(None)`: Signals that `expr` can not be rewritten. +/// - `Err(error)`: An error occurred during the function call. +/// +/// # Notes +/// This rewrite also removes any unnecessary layers of aliasing. /// /// Without trimming, we can end up with unnecessary indirections inside expressions /// during projection merges. @@ -487,84 +586,28 @@ fn merge_consecutive_projections(proj: &Projection) -> Result /// Projection((a as a1 + b as b1) as sum1) /// --Source(a, b) /// ``` -fn trim_expr(expr: Expr) -> Expr { - match expr { - Expr::Alias(alias) => trim_expr(*alias.expr), - _ => expr, - } -} - -// Check whether `expr` is trivial; i.e. it doesn't imply any computation. -fn is_expr_trivial(expr: &Expr) -> bool { - matches!(expr, Expr::Column(_) | Expr::Literal(_)) -} - -// Exit early when there is no rewrite to do. -macro_rules! rewrite_expr_with_check { - ($expr:expr, $input:expr) => { - if let Some(value) = rewrite_expr($expr, $input)? { - value - } else { - return Ok(None); - } - }; -} - -/// Rewrites a projection expression using the projection before it (i.e. its input) -/// This is a subroutine to the `merge_consecutive_projections` function. -/// -/// # Parameters -/// -/// * `expr` - A reference to the expression to rewrite. -/// * `input` - A reference to the input of the projection expression (itself -/// a projection). -/// -/// # Returns -/// -/// A `Result` object with the following semantics: -/// -/// - `Ok(Some(Expr))`: Rewrite was successful. Contains the rewritten result. -/// - `Ok(None)`: Signals that `expr` can not be rewritten. -/// - `Err(error)`: An error occurred during the function call. -fn rewrite_expr(expr: &Expr, input: &Projection) -> Result> { - let result = match expr { - Expr::Column(col) => { - // Find index of column: - let idx = input.schema.index_of_column(col)?; - input.expr[idx].clone() - } - Expr::BinaryExpr(binary) => Expr::BinaryExpr(BinaryExpr::new( - Box::new(trim_expr(rewrite_expr_with_check!(&binary.left, input))), - binary.op, - Box::new(trim_expr(rewrite_expr_with_check!(&binary.right, input))), - )), - Expr::Alias(alias) => Expr::Alias(Alias::new( - trim_expr(rewrite_expr_with_check!(&alias.expr, input)), - alias.relation.clone(), - alias.name.clone(), - )), - Expr::Literal(_) => expr.clone(), - Expr::Cast(cast) => { - let new_expr = rewrite_expr_with_check!(&cast.expr, input); - Expr::Cast(Cast::new(Box::new(new_expr), cast.data_type.clone())) - } - Expr::ScalarFunction(scalar_fn) => { - return Ok(scalar_fn - .args - .iter() - .map(|expr| rewrite_expr(expr, input)) - .collect::>>()? - .map(|new_args| { - Expr::ScalarFunction(ScalarFunction::new_func_def( - scalar_fn.func_def.clone(), - new_args, - )) - })); +fn rewrite_expr(expr: Expr, input: &Projection) -> Result> { + expr.transform_up(|expr| { + match expr { + // remove any intermediate aliases + Expr::Alias(alias) => Ok(Transformed::yes(*alias.expr)), + Expr::Column(col) => { + // Find index of column: + let idx = input.schema.index_of_column(&col)?; + // get the corresponding unaliased input expression + // + // For example: + // * the input projection is [`a + b` as c, `d + e` as f] + // * the current column is an expression "f" + // + // return the expression `d + e` (not `d + e` as f) + let input_expr = input.expr[idx].clone().unalias_nested(); + Ok(Transformed::yes(input_expr)) + } + // Unsupported type for consecutive projection merge analysis. + _ => Ok(Transformed::no(expr)), } - // Unsupported type for consecutive projection merge analysis. - _ => return Ok(None), - }; - Ok(Some(result)) + }) } /// Accumulates outer-referenced columns by the @@ -682,19 +725,18 @@ fn split_join_requirements( /// /// # Returns /// -/// A `Result` containing a tuple with two values: The resulting `LogicalPlan` -/// (with or without the added projection) and a `bool` flag indicating if a -/// projection was added (`true`) or not (`false`). +/// A `Transformed` indicating if a projection was added fn add_projection_on_top_if_helpful( plan: LogicalPlan, project_exprs: Vec, -) -> Result<(LogicalPlan, bool)> { +) -> Result> { // Make sure projection decreases the number of columns, otherwise it is unnecessary. if project_exprs.len() >= plan.schema().fields().len() { - Ok((plan, false)) + Ok(Transformed::no(plan)) } else { Projection::try_new(project_exprs, Arc::new(plan)) - .map(|proj| (LogicalPlan::Projection(proj), true)) + .map(LogicalPlan::Projection) + .map(Transformed::yes) } } @@ -716,37 +758,30 @@ fn add_projection_on_top_if_helpful( /// - `Ok(None)`: No rewrite necessary. /// - `Err(error)`: An error occured during the function call. fn rewrite_projection_given_requirements( - proj: &Projection, + proj: Projection, config: &dyn OptimizerConfig, indices: RequiredIndicies, -) -> Result> { - let exprs_used = indices.get_at_indices(&proj.expr); +) -> Result> { + let Projection { expr, input, .. } = proj; + + let exprs_used = indices.get_at_indices(&expr); let required_indices = - RequiredIndicies::new().with_exprs(proj.input.schema(), exprs_used.iter())?; - return if let Some(input) = - optimize_projections(&proj.input, config, required_indices)? - { - if is_projection_unnecessary(&input, &exprs_used)? { - Ok(Some(input)) - } else { - Projection::try_new(exprs_used, Arc::new(input)) - .map(|proj| Some(LogicalPlan::Projection(proj))) - } - } else if exprs_used.len() < proj.expr.len() { - // Projection expression used is different than the existing projection. - // In this case, even if the child doesn't change, we should update the - // projection to use fewer columns: - if is_projection_unnecessary(&proj.input, &exprs_used)? { - Ok(Some(proj.input.as_ref().clone())) - } else { - Projection::try_new(exprs_used, proj.input.clone()) - .map(|proj| Some(LogicalPlan::Projection(proj))) - } - } else { - // Projection doesn't change. - Ok(None) - }; + RequiredIndicies::new().with_exprs(input.schema(), exprs_used.iter())?; + + // rewrite the children projection, and if they are changed rewrite the + // projection down + optimize_projections(unwrap_arc(input), config, required_indices)?.transform_data( + |input| { + if is_projection_unnecessary(&input, &exprs_used)? { + Ok(Transformed::yes(input)) + } else { + Projection::try_new(exprs_used, Arc::new(input)) + .map(LogicalPlan::Projection) + .map(Transformed::yes) + } + }, + ) } /// Projection is unnecessary, when @@ -761,6 +796,7 @@ fn is_projection_unnecessary(input: &LogicalPlan, proj_exprs: &[Expr]) -> Result mod tests { use std::collections::HashMap; use std::fmt::Formatter; + use std::ops::Add; use std::sync::Arc; use std::vec; @@ -1184,13 +1220,32 @@ mod tests { assert_optimized_plan_equal(plan, expected) } + // Test Case expression + #[test] + fn test_case_merged() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a"), lit(0).alias("d")])? + .project(vec![ + col("a"), + when(col("a").eq(lit(1)), lit(10)) + .otherwise(col("d"))? + .alias("d"), + ])? + .build()?; + + let expected = "Projection: test.a, CASE WHEN test.a = Int32(1) THEN Int32(10) ELSE Int32(0) END AS d\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(plan, expected) + } + // Test outer projection isn't discarded despite the same schema as inner // https://github.com/apache/datafusion/issues/8942 #[test] fn test_derived_column() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) - .project(vec![col("a"), lit(0).alias("d")])? + .project(vec![col("a").add(lit(1)).alias("a"), lit(0).alias("d")])? .project(vec![ col("a"), when(col("a").eq(lit(1)), lit(10)) @@ -1199,8 +1254,9 @@ mod tests { ])? .build()?; - let expected = "Projection: test.a, CASE WHEN test.a = Int32(1) THEN Int32(10) ELSE d END AS d\ - \n Projection: test.a, Int32(0) AS d\ + let expected = + "Projection: a, CASE WHEN a = Int32(1) THEN Int32(10) ELSE d END AS d\ + \n Projection: test.a + Int32(1) AS a, Int32(0) AS d\ \n TableScan: test projection=[a]"; assert_optimized_plan_equal(plan, expected) } diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 1c20501da53a3..fd47cb23b108d 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -295,6 +295,13 @@ impl NamePreserver { } } + /// Create a new NamePreserver for rewriting the `expr`s in `Projection` + /// + /// This will use aliases + pub fn new_for_projection() -> Self { + Self { use_alias: true } + } + pub fn save(&self, expr: &Expr) -> Result { let original_name = if self.use_alias { Some(expr.name_for_alias()?) diff --git a/datafusion/sqllogictest/test_files/cte.slt b/datafusion/sqllogictest/test_files/cte.slt index e9c508cb27fb1..05491772999e9 100644 --- a/datafusion/sqllogictest/test_files/cte.slt +++ b/datafusion/sqllogictest/test_files/cte.slt @@ -31,10 +31,9 @@ query TT EXPLAIN WITH "NUMBERS" AS (SELECT 1 as a, 2 as b, 3 as c) SELECT "NUMBERS".* FROM "NUMBERS" ---- logical_plan -01)Projection: NUMBERS.a, NUMBERS.b, NUMBERS.c -02)--SubqueryAlias: NUMBERS -03)----Projection: Int64(1) AS a, Int64(2) AS b, Int64(3) AS c -04)------EmptyRelation +01)SubqueryAlias: NUMBERS +02)--Projection: Int64(1) AS a, Int64(2) AS b, Int64(3) AS c +03)----EmptyRelation physical_plan 01)ProjectionExec: expr=[1 as a, 2 as b, 3 as c] 02)--PlaceholderRowExec @@ -105,14 +104,13 @@ EXPLAIN WITH RECURSIVE nodes AS ( SELECT * FROM nodes ---- logical_plan -01)Projection: nodes.id -02)--SubqueryAlias: nodes -03)----RecursiveQuery: is_distinct=false -04)------Projection: Int64(1) AS id -05)--------EmptyRelation -06)------Projection: nodes.id + Int64(1) AS id -07)--------Filter: nodes.id < Int64(10) -08)----------TableScan: nodes +01)SubqueryAlias: nodes +02)--RecursiveQuery: is_distinct=false +03)----Projection: Int64(1) AS id +04)------EmptyRelation +05)----Projection: nodes.id + Int64(1) AS id +06)------Filter: nodes.id < Int64(10) +07)--------TableScan: nodes physical_plan 01)RecursiveQueryExec: name=nodes, is_distinct=false 02)--ProjectionExec: expr=[1 as id] @@ -152,14 +150,13 @@ ORDER BY time, name, account_balance ---- logical_plan 01)Sort: balances.time ASC NULLS LAST, balances.name ASC NULLS LAST, balances.account_balance ASC NULLS LAST -02)--Projection: balances.time, balances.name, balances.account_balance -03)----SubqueryAlias: balances -04)------RecursiveQuery: is_distinct=false -05)--------Projection: balance.time, balance.name, balance.account_balance -06)----------TableScan: balance -07)--------Projection: balances.time + Int64(1) AS time, balances.name, balances.account_balance + Int64(10) AS account_balance -08)----------Filter: balances.time < Int64(10) -09)------------TableScan: balances +02)--SubqueryAlias: balances +03)----RecursiveQuery: is_distinct=false +04)------Projection: balance.time, balance.name, balance.account_balance +05)--------TableScan: balance +06)------Projection: balances.time + Int64(1) AS time, balances.name, balances.account_balance + Int64(10) AS account_balance +07)--------Filter: balances.time < Int64(10) +08)----------TableScan: balances physical_plan 01)SortExec: expr=[time@0 ASC NULLS LAST,name@1 ASC NULLS LAST,account_balance@2 ASC NULLS LAST], preserve_partitioning=[false] 02)--RecursiveQueryExec: name=balances, is_distinct=false @@ -720,18 +717,17 @@ explain WITH RECURSIVE recursive_cte AS ( SELECT * FROM recursive_cte; ---- logical_plan -01)Projection: recursive_cte.val -02)--SubqueryAlias: recursive_cte -03)----RecursiveQuery: is_distinct=false -04)------Projection: Int64(1) AS val -05)--------EmptyRelation -06)------Projection: Int64(2) AS val -07)--------CrossJoin: -08)----------Filter: recursive_cte.val < Int64(2) -09)------------TableScan: recursive_cte -10)----------SubqueryAlias: sub_cte -11)------------Projection: Int64(2) AS val -12)--------------EmptyRelation +01)SubqueryAlias: recursive_cte +02)--RecursiveQuery: is_distinct=false +03)----Projection: Int64(1) AS val +04)------EmptyRelation +05)----Projection: Int64(2) AS val +06)------CrossJoin: +07)--------Filter: recursive_cte.val < Int64(2) +08)----------TableScan: recursive_cte +09)--------SubqueryAlias: sub_cte +10)----------Projection: Int64(2) AS val +11)------------EmptyRelation physical_plan 01)RecursiveQueryExec: name=recursive_cte, is_distinct=false 02)--ProjectionExec: expr=[1 as val] diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index f4a86fd02709c..e2c77552f9903 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -3602,17 +3602,16 @@ EXPLAIN SELECT * FROM ( ) as a FULL JOIN (SELECT 1 as e, 3 AS f) AS rhs ON a.c=rhs.e; ---- logical_plan -01)Projection: a.c, a.d, rhs.e, rhs.f -02)--Full Join: a.c = rhs.e -03)----SubqueryAlias: a -04)------Union -05)--------Projection: Int64(1) AS c, Int64(2) AS d -06)----------EmptyRelation -07)--------Projection: Int64(1) AS c, Int64(3) AS d -08)----------EmptyRelation -09)----SubqueryAlias: rhs -10)------Projection: Int64(1) AS e, Int64(3) AS f -11)--------EmptyRelation +01)Full Join: a.c = rhs.e +02)--SubqueryAlias: a +03)----Union +04)------Projection: Int64(1) AS c, Int64(2) AS d +05)--------EmptyRelation +06)------Projection: Int64(1) AS c, Int64(3) AS d +07)--------EmptyRelation +08)--SubqueryAlias: rhs +09)----Projection: Int64(1) AS e, Int64(3) AS f +10)------EmptyRelation physical_plan 01)ProjectionExec: expr=[c@2 as c, d@3 as d, e@0 as e, f@1 as f] 02)--CoalesceBatchesExec: target_batch_size=2 @@ -3650,17 +3649,16 @@ EXPLAIN SELECT * FROM ( ) as a FULL JOIN (SELECT 1 as e, 3 AS f) AS rhs ON a.c=rhs.e; ---- logical_plan -01)Projection: a.c, a.d, rhs.e, rhs.f -02)--Full Join: a.c = rhs.e -03)----SubqueryAlias: a -04)------Union -05)--------Projection: Int64(1) AS c, Int64(2) AS d -06)----------EmptyRelation -07)--------Projection: Int64(1) AS c, Int64(3) AS d -08)----------EmptyRelation -09)----SubqueryAlias: rhs -10)------Projection: Int64(1) AS e, Int64(3) AS f -11)--------EmptyRelation +01)Full Join: a.c = rhs.e +02)--SubqueryAlias: a +03)----Union +04)------Projection: Int64(1) AS c, Int64(2) AS d +05)--------EmptyRelation +06)------Projection: Int64(1) AS c, Int64(3) AS d +07)--------EmptyRelation +08)--SubqueryAlias: rhs +09)----Projection: Int64(1) AS e, Int64(3) AS f +10)------EmptyRelation physical_plan 01)ProjectionExec: expr=[c@2 as c, d@3 as d, e@0 as e, f@1 as f] 02)--CoalesceBatchesExec: target_batch_size=2