diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 10479f29a583..9aa226853cb2 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -22,7 +22,7 @@ use std::sync::Arc; use arrow::datatypes::{DataType, IntervalUnit}; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; use datafusion_common::{ exec_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, DFSchema, DataFusionError, Result, ScalarValue, @@ -31,8 +31,8 @@ use datafusion_expr::expr::{ self, AggregateFunctionDefinition, Between, BinaryExpr, Case, Exists, InList, InSubquery, Like, ScalarFunction, WindowFunction, }; -use datafusion_expr::expr_rewriter::rewrite_preserving_name; use datafusion_expr::expr_schema::cast_subquery; +use datafusion_expr::logical_plan::tree_node::unwrap_arc; use datafusion_expr::logical_plan::Subquery; use datafusion_expr::type_coercion::binary::{ comparison_coercion, get_input_types, like_coercion, @@ -51,6 +51,7 @@ use datafusion_expr::{ }; use crate::analyzer::AnalyzerRule; +use crate::utils::NamePreserver; #[derive(Default)] pub struct TypeCoercion {} @@ -67,26 +68,28 @@ impl AnalyzerRule for TypeCoercion { } fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { - analyze_internal(&DFSchema::empty(), &plan) + let empty_schema = DFSchema::empty(); + + let transformed_plan = plan + .transform_up_with_subqueries(|plan| analyze_internal(&empty_schema, plan))? + .data; + + Ok(transformed_plan) } } +/// use the external schema to handle the correlated subqueries case +/// +/// Assumes that children have already been optimized fn analyze_internal( - // use the external schema to handle the correlated subqueries case external_schema: &DFSchema, - plan: &LogicalPlan, -) -> Result { - // optimize child plans first - let new_inputs = plan - .inputs() - .iter() - .map(|p| analyze_internal(external_schema, p)) - .collect::>>()?; + plan: LogicalPlan, +) -> Result> { // get schema representing all available input fields. This is used for data type // resolution only, so order does not matter here - let mut schema = merge_schema(new_inputs.iter().collect()); + let mut schema = merge_schema(plan.inputs()); - if let LogicalPlan::TableScan(ts) = plan { + if let LogicalPlan::TableScan(ts) = &plan { let source_schema = DFSchema::try_from_qualified_schema( ts.table_name.clone(), &ts.source.schema(), @@ -99,25 +102,75 @@ fn analyze_internal( // select t2.c2 from t1 where t1.c1 in (select t2.c1 from t2 where t2.c2=t1.c3) schema.merge(external_schema); - let mut expr_rewrite = TypeCoercionRewriter { schema: &schema }; - - let new_expr = plan - .expressions() - .into_iter() - .map(|expr| { - // ensure aggregate names don't change: - // https://github.com/apache/datafusion/issues/3555 - rewrite_preserving_name(expr, &mut expr_rewrite) - }) - .collect::>>()?; - - plan.with_new_exprs(new_expr, new_inputs) + let mut expr_rewrite = TypeCoercionRewriter::new(&schema); + + let name_preserver = NamePreserver::new(&plan); + // apply coercion rewrite all expressions in the plan indivdually + plan.map_expressions(|expr| { + let original_name = name_preserver.save(&expr)?; + expr.rewrite(&mut expr_rewrite)? + .map_data(|expr| original_name.restore(expr)) + })? + // coerce join expressions specially + .map_data(|plan| expr_rewrite.coerce_joins(plan))? + // recompute the schema after the expressions have been rewritten as the types may have changed + .map_data(|plan| plan.recompute_schema()) } pub(crate) struct TypeCoercionRewriter<'a> { pub(crate) schema: &'a DFSchema, } +impl<'a> TypeCoercionRewriter<'a> { + fn new(schema: &'a DFSchema) -> Self { + Self { schema } + } + + /// Coerce join equality expressions + /// + /// Joins must be treated specially as their equality expressions are stored + /// as a parallel list of left and right expressions, rather than a single + /// equality expression + /// + /// For example, on_exprs like `t1.a = t2.b AND t1.x = t2.y` will be stored + /// as a list of `(t1.a, t2.b), (t1.x, t2.y)` + fn coerce_joins(&mut self, plan: LogicalPlan) -> Result { + let LogicalPlan::Join(mut join) = plan else { + return Ok(plan); + }; + + join.on = join + .on + .into_iter() + .map(|(lhs, rhs)| { + // coerce the arguments as though they were a single binary equality + // expression + let (lhs, rhs) = self.coerce_binary_op(lhs, Operator::Eq, rhs)?; + Ok((lhs, rhs)) + }) + .collect::>>()?; + + Ok(LogicalPlan::Join(join)) + } + + fn coerce_binary_op( + &self, + left: Expr, + op: Operator, + right: Expr, + ) -> Result<(Expr, Expr)> { + let (left_type, right_type) = get_input_types( + &left.get_type(self.schema)?, + &op, + &right.get_type(self.schema)?, + )?; + Ok(( + left.cast_to(&left_type, self.schema)?, + right.cast_to(&right_type, self.schema)?, + )) + } +} + impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { type Node = Expr; @@ -130,14 +183,15 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { subquery, outer_ref_columns, }) => { - let new_plan = analyze_internal(self.schema, &subquery)?; + let new_plan = analyze_internal(self.schema, unwrap_arc(subquery))?.data; Ok(Transformed::yes(Expr::ScalarSubquery(Subquery { subquery: Arc::new(new_plan), outer_ref_columns, }))) } Expr::Exists(Exists { subquery, negated }) => { - let new_plan = analyze_internal(self.schema, &subquery.subquery)?; + let new_plan = + analyze_internal(self.schema, unwrap_arc(subquery.subquery))?.data; Ok(Transformed::yes(Expr::Exists(Exists { subquery: Subquery { subquery: Arc::new(new_plan), @@ -151,7 +205,8 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { subquery, negated, }) => { - let new_plan = analyze_internal(self.schema, &subquery.subquery)?; + let new_plan = + analyze_internal(self.schema, unwrap_arc(subquery.subquery))?.data; let expr_type = expr.get_type(self.schema)?; let subquery_type = new_plan.schema().field(0).data_type(); let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(plan_datafusion_err!( @@ -220,15 +275,11 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { )))) } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let (left_type, right_type) = get_input_types( - &left.get_type(self.schema)?, - &op, - &right.get_type(self.schema)?, - )?; + let (left, right) = self.coerce_binary_op(*left, op, *right)?; Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new( - Box::new(left.cast_to(&left_type, self.schema)?), + Box::new(left), op, - Box::new(right.cast_to(&right_type, self.schema)?), + Box::new(right), )))) } Expr::Between(Between {