Skip to content

Commit

Permalink
Avoid copies in TypeCoercion via TreeNode API
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Apr 23, 2024
1 parent 15045a8 commit a98adc7
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 25 deletions.
6 changes: 6 additions & 0 deletions datafusion/common/src/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,12 @@ impl<T> Transformed<T> {
Self::new(data, false, TreeNodeRecursion::Continue)
}

/// If not already, sets `self.transformed` to true if `transformed` is true.
pub fn update_transformed(mut self, transformed: bool) -> Self {
self.transformed |= transformed;
self
}

/// Applies the given `f` to the data of this [`Transformed`] object.
pub fn update_data<U, F: FnOnce(T) -> U>(self, f: F) -> Transformed<U> {
Transformed::new(f(self.data), self.transformed, self.tnr)
Expand Down
54 changes: 29 additions & 25 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use std::sync::Arc;
use arrow::datatypes::{DataType, IntervalUnit};

use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TreeNodeRewriter};
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
use datafusion_common::{
exec_err, internal_err, plan_datafusion_err, plan_err, DFSchema, DFSchemaRef,
DataFusionError, Result, ScalarValue,
Expand All @@ -31,8 +31,8 @@ use datafusion_expr::expr::{
self, AggregateFunctionDefinition, Between, BinaryExpr, Case, Exists, InList,
InSubquery, Like, ScalarFunction, WindowFunction,
};
use datafusion_expr::expr_rewriter::rewrite_preserving_name;
use datafusion_expr::expr_schema::cast_subquery;
use datafusion_expr::logical_plan::tree_node::unwrap_arc;
use datafusion_expr::logical_plan::Subquery;
use datafusion_expr::type_coercion::binary::{
comparison_coercion, get_input_types, like_coercion,
Expand All @@ -51,6 +51,7 @@ use datafusion_expr::{
};

use crate::analyzer::AnalyzerRule;
use crate::utils::NamePreserver;

#[derive(Default)]
pub struct TypeCoercion {}
Expand All @@ -67,26 +68,27 @@ impl AnalyzerRule for TypeCoercion {
}

fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result<LogicalPlan> {
analyze_internal(&DFSchema::empty(), &plan)
Ok(analyze_internal(&DFSchema::empty(), plan)?.data)
}
}

fn analyze_internal(
// use the external schema to handle the correlated subqueries case
external_schema: &DFSchema,
plan: &LogicalPlan,
) -> Result<LogicalPlan> {
// optimize child plans first
let new_inputs = plan
.inputs()
.iter()
.map(|p| analyze_internal(external_schema, p))
.collect::<Result<Vec<_>>>()?;
plan: LogicalPlan,
) -> Result<Transformed<LogicalPlan>> {
// optimize child plans first (since we use external_schema here, can't use LogicalPlan::transform)
let Transformed {
data: plan,
transformed: children_transformed,
..
} = plan.map_children(|plan| analyze_internal(external_schema, plan))?;

// get schema representing all available input fields. This is used for data type
// resolution only, so order does not matter here
let mut schema = merge_schema(new_inputs.iter().collect());
let mut schema = merge_schema(plan.inputs());

if let LogicalPlan::TableScan(ts) = plan {
if let LogicalPlan::TableScan(ts) = &plan {
let source_schema = DFSchema::try_from_qualified_schema(
ts.table_name.clone(),
&ts.source.schema(),
Expand All @@ -103,17 +105,17 @@ fn analyze_internal(
schema: Arc::new(schema),
};

let new_expr = plan
.expressions()
.into_iter()
.map(|expr| {
let preserver = NamePreserver::new(&plan);
Ok(plan
.map_expressions(|expr| {
// ensure aggregate names don't change:
// https://github.com/apache/datafusion/issues/3555
rewrite_preserving_name(expr, &mut expr_rewrite)
})
.collect::<Result<Vec<_>>>()?;

plan.with_new_exprs(new_expr, new_inputs)
let original_name = preserver.save(&expr)?;
expr.rewrite(&mut expr_rewrite)?
.map_data(|expr| original_name.restore(expr))
})?
// propagate the the transformation information from children
.update_transformed(children_transformed))
}

pub(crate) struct TypeCoercionRewriter {
Expand All @@ -132,14 +134,15 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
subquery,
outer_ref_columns,
}) => {
let new_plan = analyze_internal(&self.schema, &subquery)?;
let new_plan = analyze_internal(&self.schema, unwrap_arc(subquery))?.data;
Ok(Transformed::yes(Expr::ScalarSubquery(Subquery {
subquery: Arc::new(new_plan),
outer_ref_columns,
})))
}
Expr::Exists(Exists { subquery, negated }) => {
let new_plan = analyze_internal(&self.schema, &subquery.subquery)?;
let new_plan =
analyze_internal(&self.schema, unwrap_arc(subquery.subquery))?.data;
Ok(Transformed::yes(Expr::Exists(Exists {
subquery: Subquery {
subquery: Arc::new(new_plan),
Expand All @@ -153,7 +156,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
subquery,
negated,
}) => {
let new_plan = analyze_internal(&self.schema, &subquery.subquery)?;
let new_plan =
analyze_internal(&self.schema, unwrap_arc(subquery.subquery))?.data;
let expr_type = expr.get_type(&self.schema)?;
let subquery_type = new_plan.schema().field(0).data_type();
let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(plan_datafusion_err!(
Expand Down

0 comments on commit a98adc7

Please sign in to comment.