diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 4d10298b60ef5..23505b4e52f79 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -17,19 +17,19 @@ //! Expression rewriter +use crate::expr::{Alias, Sort, Unnest}; +use crate::logical_plan::Projection; +use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder}; +use arrow::datatypes::DataType; use std::collections::HashMap; use std::collections::HashSet; use std::fmt::Debug; use std::sync::Arc; -use crate::expr::{Alias, Sort, Unnest}; -use crate::logical_plan::Projection; -use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder}; - use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::TableReference; use datafusion_common::{Column, DFSchema, Result}; -use datafusion_common::{ScalarValue, TableReference}; mod order_by; pub use order_by::rewrite_sort_cols_by_aggs; @@ -150,14 +150,17 @@ pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Resul pub fn replace_expr_with_null( expr: Expr, - replace_columns: &HashSet<&Column>, + replace_columns: &HashMap<&Column, &DataType>, ) -> Result { expr.transform(|expr| { Ok({ match &expr { - Expr::Column(c) if replace_columns.contains(c) => { - Transformed::yes(Expr::Literal(ScalarValue::Null)) - } + Expr::Column(c) => match replace_columns.get(c) { + Some(data_type) => { + Transformed::yes(Expr::Literal((*data_type).try_into()?)) + } + None => Transformed::no(expr), + }, _ => Transformed::no(expr), } }) diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs index a6ca1304a5457..95b2cc75b0d74 100644 --- a/datafusion/optimizer/src/eliminate_outer_join.rs +++ b/datafusion/optimizer/src/eliminate_outer_join.rs @@ -426,6 +426,7 @@ mod tests { let t2 = test_table_scan_with_name("t2")?; let fun = Arc::new(ScalarUDF::new_from_impl(DoNothingUdf::new())); + // eliminate to inner join let plan = LogicalPlanBuilder::from(t1) .join( t2, @@ -493,6 +494,7 @@ mod tests { let t2 = test_table_scan_with_name("t2")?; let fun = Arc::new(ScalarUDF::new_from_impl(AlwaysNullUdf::new())); + // could not eliminate to inner join let plan = LogicalPlanBuilder::from(t1) .join( t2, @@ -543,6 +545,10 @@ mod tests { fn return_type(&self, _arg_types: &[DataType]) -> Result { Ok(DataType::Boolean) } + + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + panic!() + } } #[test] @@ -551,6 +557,7 @@ mod tests { let t2 = test_table_scan_with_name("t2")?; let fun = Arc::new(ScalarUDF::new_from_impl(VolatileUdf::new())); + // could not eliminate to inner join let plan = LogicalPlanBuilder::from(t1) .join( t2, diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index dae72fb474563..6cace7caa03e5 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -619,17 +619,16 @@ impl InferredPredicates { predicate: Expr, replace_map: &HashMap<&Column, &Column>, ) -> Result<()> { + let inferred = replace_col(predicate, replace_map)?; if self.is_inner_join - || matches!( - is_restrict_null_predicate( - &self.join_schema, - predicate.clone(), - replace_map.keys().cloned() - ), - Ok(true) + || is_restrict_null_predicate( + &self.join_schema, + inferred.clone(), + replace_map.values().cloned(), ) + .unwrap_or(false) { - self.predicates.push(replace_col(predicate, replace_map)?); + self.predicates.push(inferred); } Ok(()) diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 7b12f56af01b6..f77448332c087 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -20,11 +20,8 @@ use std::collections::{BTreeSet, HashMap, HashSet}; use crate::{OptimizerConfig, OptimizerRule}; - -use crate::analyzer::type_coercion::TypeCoercionRewriter; use arrow::datatypes::DataType; -use datafusion_common::tree_node::{TransformedResult, TreeNode}; -use datafusion_common::{Column, DFSchema, Result, ScalarValue}; +use datafusion_common::{Column, DFSchema, ExprSchema, Result, ScalarValue}; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr_rewriter::{replace_col, replace_expr_with_null}; use datafusion_expr::{logical_plan::LogicalPlan, Expr, ExprSchemable}; @@ -136,36 +133,39 @@ pub fn is_restrict_null_predicate<'a>( return Ok(true); } - let execution_props = ExecutionProps::default(); - let replace_columns = cols_of_predicate.into_iter().collect(); + let replace_columns = cols_of_predicate + .into_iter() + .map(|col| { + let data_type = input_schema.data_type(col)?; + Ok((col, data_type)) + }) + .collect::>>()?; let replaced_predicate = replace_expr_with_null(predicate, &replace_columns)?; - let coerced_predicate = coerce(replaced_predicate, input_schema)?; + let execution_props = ExecutionProps::default(); let info = SimplifyContext::new(&execution_props) .with_schema(Arc::new(input_schema.clone())); let simplifier = ExprSimplifier::new(info).with_canonicalize(false); - let expr = simplifier.simplify(coerced_predicate)?; + let expr = simplifier.simplify(replaced_predicate)?; + + if matches!(expr.get_type(input_schema)?, DataType::Null) { + return Ok(true); + } let ret = match &expr { Expr::Literal(scalar) if scalar.is_null() => true, Expr::Literal(ScalarValue::Boolean(Some(b))) => !b, - _ if matches!(expr.get_type(input_schema)?, DataType::Null) => true, _ => false, }; Ok(ret) } -fn coerce(expr: Expr, schema: &DFSchema) -> Result { - let mut expr_rewrite = TypeCoercionRewriter { schema }; - expr.rewrite(&mut expr_rewrite).data() -} - #[cfg(test)] mod tests { use super::*; use arrow::datatypes::{Field, Schema}; - use datafusion_expr::{binary_expr, case, col, in_list, is_null, lit, Operator}; + use datafusion_expr::{binary_expr, case, col, in_list, lit, Operator}; #[test] fn expr_is_restrict_null_predicate() -> Result<()> { @@ -173,9 +173,9 @@ mod tests { // a (col("a"), true), // a IS NULL - (is_null(col("a")), false), + (col("a").is_null(), false), // a IS NOT NULL - (Expr::IsNotNull(Box::new(col("a"))), true), + (col("a").is_not_null(), true), // a = NULL ( binary_expr(col("a"), Operator::Eq, Expr::Literal(ScalarValue::Null)), @@ -185,25 +185,34 @@ mod tests { (binary_expr(col("a"), Operator::Gt, lit(8i64)), true), // a <= 8 (binary_expr(col("a"), Operator::LtEq, lit(8i32)), true), - // CASE a WHEN 1 THEN true WHEN 0 THEN false ELSE NULL END + // CASE a WHEN Int32(1) THEN Boolean(true) WHEN Int32(0) THEN Boolean(false) ELSE NULL END ( case(col("a")) - .when(lit(1i64), lit(true)) - .when(lit(0i64), lit(false)) + .when(lit(1i32), lit(true)) + .when(lit(0i32), lit(false)) .otherwise(lit(ScalarValue::Null))?, true, ), - // CASE a WHEN 1 THEN true ELSE false END + // CASE a WHEN Int64(1) THEN Boolean(true) WHEN Int32(0) THEN Boolean(false) ELSE NULL END + // Because of 1 is Int64, this expr can not be simplified. ( case(col("a")) .when(lit(1i64), lit(true)) + .when(lit(0i32), lit(false)) + .otherwise(lit(ScalarValue::Null))?, + false, + ), + // CASE a WHEN 1 THEN true ELSE false END + ( + case(col("a")) + .when(lit(1i32), lit(true)) .otherwise(lit(false))?, true, ), // CASE a WHEN 0 THEN false ELSE true END ( case(col("a")) - .when(lit(0i64), lit(false)) + .when(lit(0i32), lit(false)) .otherwise(lit(true))?, false, ), @@ -211,7 +220,7 @@ mod tests { ( binary_expr( case(col("a")) - .when(lit(0i64), lit(false)) + .when(lit(0i32), lit(false)) .otherwise(lit(true))?, Operator::Or, lit(false), @@ -222,7 +231,7 @@ mod tests { ( binary_expr( case(col("a")) - .when(lit(0i64), lit(true)) + .when(lit(0i32), lit(true)) .otherwise(lit(false))?, Operator::Or, lit(false), @@ -249,8 +258,28 @@ mod tests { in_list(col("a"), vec![Expr::Literal(ScalarValue::Null)], true), true, ), - // new + // a > b (col("a").gt(col("b")), true), + // a + Int32(10) > b - UInt64(10) + ( + binary_expr(col("a"), Operator::Plus, lit(10i32)).gt(binary_expr( + col("b"), + Operator::Minus, + lit(10u64), + )), + true, + ), + // a + Int64(10) > b - UInt64(10) + // Because of DataType of a column is Int32 and DataType of lit 10 is Int64, + // the expr can not be simplified. + ( + binary_expr(col("a"), Operator::Plus, lit(10i64)).gt(binary_expr( + col("b"), + Operator::Minus, + lit(10u64), + )), + false, + ), ]; let column_a = Column::from_name("a"); @@ -261,11 +290,11 @@ mod tests { let df_schema = DFSchema::try_from(schema)?; for (predicate, expected) in test_cases { - let join_cols_of_predicate = std::iter::once(&column_a); + let cols_of_predicate = std::iter::once(&column_a); let actual = is_restrict_null_predicate( &df_schema, predicate.clone(), - join_cols_of_predicate, + cols_of_predicate, )?; assert_eq!(actual, expected, "{}", predicate); }