Skip to content

Commit

Permalink
modify function
Browse files Browse the repository at this point in the history
  • Loading branch information
JasonLi-cn committed Nov 5, 2024
1 parent 4fb183f commit 62a6dec
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 44 deletions.
21 changes: 12 additions & 9 deletions datafusion/expr/src/expr_rewriter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,19 @@

//! Expression rewriter
use crate::expr::{Alias, Sort, Unnest};
use crate::logical_plan::Projection;
use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder};
use arrow::datatypes::DataType;
use std::collections::HashMap;
use std::collections::HashSet;
use std::fmt::Debug;
use std::sync::Arc;

use crate::expr::{Alias, Sort, Unnest};
use crate::logical_plan::Projection;
use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder};

use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::TableReference;
use datafusion_common::{Column, DFSchema, Result};
use datafusion_common::{ScalarValue, TableReference};

mod order_by;
pub use order_by::rewrite_sort_cols_by_aggs;
Expand Down Expand Up @@ -150,14 +150,17 @@ pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Resul

pub fn replace_expr_with_null(
expr: Expr,
replace_columns: &HashSet<&Column>,
replace_columns: &HashMap<&Column, &DataType>,
) -> Result<Expr> {
expr.transform(|expr| {
Ok({
match &expr {
Expr::Column(c) if replace_columns.contains(c) => {
Transformed::yes(Expr::Literal(ScalarValue::Null))
}
Expr::Column(c) => match replace_columns.get(c) {
Some(data_type) => {
Transformed::yes(Expr::Literal((*data_type).try_into()?))
}
None => Transformed::no(expr),
},
_ => Transformed::no(expr),
}
})
Expand Down
7 changes: 7 additions & 0 deletions datafusion/optimizer/src/eliminate_outer_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,7 @@ mod tests {
let t2 = test_table_scan_with_name("t2")?;
let fun = Arc::new(ScalarUDF::new_from_impl(DoNothingUdf::new()));

// eliminate to inner join
let plan = LogicalPlanBuilder::from(t1)
.join(
t2,
Expand Down Expand Up @@ -493,6 +494,7 @@ mod tests {
let t2 = test_table_scan_with_name("t2")?;
let fun = Arc::new(ScalarUDF::new_from_impl(AlwaysNullUdf::new()));

// could not eliminate to inner join
let plan = LogicalPlanBuilder::from(t1)
.join(
t2,
Expand Down Expand Up @@ -543,6 +545,10 @@ mod tests {
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Boolean)
}

fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
panic!()
}
}

#[test]
Expand All @@ -551,6 +557,7 @@ mod tests {
let t2 = test_table_scan_with_name("t2")?;
let fun = Arc::new(ScalarUDF::new_from_impl(VolatileUdf::new()));

// could not eliminate to inner join
let plan = LogicalPlanBuilder::from(t1)
.join(
t2,
Expand Down
15 changes: 7 additions & 8 deletions datafusion/optimizer/src/push_down_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -619,17 +619,16 @@ impl InferredPredicates {
predicate: Expr,
replace_map: &HashMap<&Column, &Column>,
) -> Result<()> {
let inferred = replace_col(predicate, replace_map)?;
if self.is_inner_join
|| matches!(
is_restrict_null_predicate(
&self.join_schema,
predicate.clone(),
replace_map.keys().cloned()
),
Ok(true)
|| is_restrict_null_predicate(
&self.join_schema,
inferred.clone(),
replace_map.values().cloned(),
)
.unwrap_or(false)
{
self.predicates.push(replace_col(predicate, replace_map)?);
self.predicates.push(inferred);
}

Ok(())
Expand Down
83 changes: 56 additions & 27 deletions datafusion/optimizer/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,8 @@
use std::collections::{BTreeSet, HashMap, HashSet};

use crate::{OptimizerConfig, OptimizerRule};

use crate::analyzer::type_coercion::TypeCoercionRewriter;
use arrow::datatypes::DataType;
use datafusion_common::tree_node::{TransformedResult, TreeNode};
use datafusion_common::{Column, DFSchema, Result, ScalarValue};
use datafusion_common::{Column, DFSchema, ExprSchema, Result, ScalarValue};
use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::expr_rewriter::{replace_col, replace_expr_with_null};
use datafusion_expr::{logical_plan::LogicalPlan, Expr, ExprSchemable};
Expand Down Expand Up @@ -136,46 +133,49 @@ pub fn is_restrict_null_predicate<'a>(
return Ok(true);
}

let execution_props = ExecutionProps::default();
let replace_columns = cols_of_predicate.into_iter().collect();
let replace_columns = cols_of_predicate
.into_iter()
.map(|col| {
let data_type = input_schema.data_type(col)?;
Ok((col, data_type))
})
.collect::<Result<HashMap<_, _>>>()?;
let replaced_predicate = replace_expr_with_null(predicate, &replace_columns)?;
let coerced_predicate = coerce(replaced_predicate, input_schema)?;

let execution_props = ExecutionProps::default();
let info = SimplifyContext::new(&execution_props)
.with_schema(Arc::new(input_schema.clone()));
let simplifier = ExprSimplifier::new(info).with_canonicalize(false);
let expr = simplifier.simplify(coerced_predicate)?;
let expr = simplifier.simplify(replaced_predicate)?;

if matches!(expr.get_type(input_schema)?, DataType::Null) {
return Ok(true);
}

let ret = match &expr {
Expr::Literal(scalar) if scalar.is_null() => true,
Expr::Literal(ScalarValue::Boolean(Some(b))) => !b,
_ if matches!(expr.get_type(input_schema)?, DataType::Null) => true,
_ => false,
};

Ok(ret)
}

fn coerce(expr: Expr, schema: &DFSchema) -> Result<Expr> {
let mut expr_rewrite = TypeCoercionRewriter { schema };
expr.rewrite(&mut expr_rewrite).data()
}

#[cfg(test)]
mod tests {
use super::*;
use arrow::datatypes::{Field, Schema};
use datafusion_expr::{binary_expr, case, col, in_list, is_null, lit, Operator};
use datafusion_expr::{binary_expr, case, col, in_list, lit, Operator};

#[test]
fn expr_is_restrict_null_predicate() -> Result<()> {
let test_cases = vec![
// a
(col("a"), true),
// a IS NULL
(is_null(col("a")), false),
(col("a").is_null(), false),
// a IS NOT NULL
(Expr::IsNotNull(Box::new(col("a"))), true),
(col("a").is_not_null(), true),
// a = NULL
(
binary_expr(col("a"), Operator::Eq, Expr::Literal(ScalarValue::Null)),
Expand All @@ -185,33 +185,42 @@ mod tests {
(binary_expr(col("a"), Operator::Gt, lit(8i64)), true),
// a <= 8
(binary_expr(col("a"), Operator::LtEq, lit(8i32)), true),
// CASE a WHEN 1 THEN true WHEN 0 THEN false ELSE NULL END
// CASE a WHEN Int32(1) THEN Boolean(true) WHEN Int32(0) THEN Boolean(false) ELSE NULL END
(
case(col("a"))
.when(lit(1i64), lit(true))
.when(lit(0i64), lit(false))
.when(lit(1i32), lit(true))
.when(lit(0i32), lit(false))
.otherwise(lit(ScalarValue::Null))?,
true,
),
// CASE a WHEN 1 THEN true ELSE false END
// CASE a WHEN Int64(1) THEN Boolean(true) WHEN Int32(0) THEN Boolean(false) ELSE NULL END
// Because of 1 is Int64, this expr can not be simplified.
(
case(col("a"))
.when(lit(1i64), lit(true))
.when(lit(0i32), lit(false))
.otherwise(lit(ScalarValue::Null))?,
false,
),
// CASE a WHEN 1 THEN true ELSE false END
(
case(col("a"))
.when(lit(1i32), lit(true))
.otherwise(lit(false))?,
true,
),
// CASE a WHEN 0 THEN false ELSE true END
(
case(col("a"))
.when(lit(0i64), lit(false))
.when(lit(0i32), lit(false))
.otherwise(lit(true))?,
false,
),
// (CASE a WHEN 0 THEN false ELSE true END) OR false
(
binary_expr(
case(col("a"))
.when(lit(0i64), lit(false))
.when(lit(0i32), lit(false))
.otherwise(lit(true))?,
Operator::Or,
lit(false),
Expand All @@ -222,7 +231,7 @@ mod tests {
(
binary_expr(
case(col("a"))
.when(lit(0i64), lit(true))
.when(lit(0i32), lit(true))
.otherwise(lit(false))?,
Operator::Or,
lit(false),
Expand All @@ -249,8 +258,28 @@ mod tests {
in_list(col("a"), vec![Expr::Literal(ScalarValue::Null)], true),
true,
),
// new
// a > b
(col("a").gt(col("b")), true),
// a + Int32(10) > b - UInt64(10)
(
binary_expr(col("a"), Operator::Plus, lit(10i32)).gt(binary_expr(
col("b"),
Operator::Minus,
lit(10u64),
)),
true,
),
// a + Int64(10) > b - UInt64(10)
// Because of DataType of a column is Int32 and DataType of lit 10 is Int64,
// the expr can not be simplified.
(
binary_expr(col("a"), Operator::Plus, lit(10i64)).gt(binary_expr(
col("b"),
Operator::Minus,
lit(10u64),
)),
false,
),
];

let column_a = Column::from_name("a");
Expand All @@ -261,11 +290,11 @@ mod tests {
let df_schema = DFSchema::try_from(schema)?;

for (predicate, expected) in test_cases {
let join_cols_of_predicate = std::iter::once(&column_a);
let cols_of_predicate = std::iter::once(&column_a);
let actual = is_restrict_null_predicate(
&df_schema,
predicate.clone(),
join_cols_of_predicate,
cols_of_predicate,
)?;
assert_eq!(actual, expected, "{}", predicate);
}
Expand Down

0 comments on commit 62a6dec

Please sign in to comment.