Skip to content

Commit

Permalink
consider volatile function in simply_expression (#13128)
Browse files Browse the repository at this point in the history
* consider volatile function in simply_expression

* refactor and fix bugs

* fix clippy

* refactor

* refactor

* format

* fix clippy

* Resolve logical conflict

* simplify more

---------

Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
Lordworms and alamb authored Nov 1, 2024
1 parent b7f4db4 commit 6b76a35
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 8 deletions.
73 changes: 69 additions & 4 deletions datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -862,8 +862,8 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
right,
}) if has_common_conjunction(&left, &right) => {
let lhs: IndexSet<Expr> = iter_conjunction_owned(*left).collect();
let (common, rhs): (Vec<_>, Vec<_>) =
iter_conjunction_owned(*right).partition(|e| lhs.contains(e));
let (common, rhs): (Vec<_>, Vec<_>) = iter_conjunction_owned(*right)
.partition(|e| lhs.contains(e) && !e.is_volatile());

let new_rhs = rhs.into_iter().reduce(and);
let new_lhs = lhs.into_iter().filter(|e| !common.contains(e)).reduce(and);
Expand Down Expand Up @@ -1682,8 +1682,8 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
}

fn has_common_conjunction(lhs: &Expr, rhs: &Expr) -> bool {
let lhs: HashSet<&Expr> = iter_conjunction(lhs).collect();
iter_conjunction(rhs).any(|e| lhs.contains(&e))
let lhs_set: HashSet<&Expr> = iter_conjunction(lhs).collect();
iter_conjunction(rhs).any(|e| lhs_set.contains(&e) && !e.is_volatile())
}

// TODO: We might not need this after defer pattern for Box is stabilized. https://github.com/rust-lang/rust/issues/87121
Expand Down Expand Up @@ -3978,4 +3978,69 @@ mod tests {
unimplemented!("not needed for tests")
}
}
#[derive(Debug)]
struct VolatileUdf {
signature: Signature,
}

impl VolatileUdf {
pub fn new() -> Self {
Self {
signature: Signature::exact(vec![], Volatility::Volatile),
}
}
}
impl ScalarUDFImpl for VolatileUdf {
fn as_any(&self) -> &dyn std::any::Any {
self
}

fn name(&self) -> &str {
"VolatileUdf"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Int16)
}
}
#[test]
fn test_optimize_volatile_conditions() {
let fun = Arc::new(ScalarUDF::new_from_impl(VolatileUdf::new()));
let rand = Expr::ScalarFunction(ScalarFunction::new_udf(fun, vec![]));
{
let expr = rand
.clone()
.eq(lit(0))
.or(col("column1").eq(lit(2)).and(rand.clone().eq(lit(0))));

assert_eq!(simplify(expr.clone()), expr);
}

{
let expr = col("column1")
.eq(lit(2))
.or(col("column1").eq(lit(2)).and(rand.clone().eq(lit(0))));

assert_eq!(simplify(expr), col("column1").eq(lit(2)));
}

{
let expr = (col("column1").eq(lit(2)).and(rand.clone().eq(lit(0)))).or(col(
"column1",
)
.eq(lit(2))
.and(rand.clone().eq(lit(0))));

assert_eq!(
simplify(expr),
col("column1")
.eq(lit(2))
.and((rand.clone().eq(lit(0))).or(rand.clone().eq(lit(0))))
);
}
}
}
13 changes: 9 additions & 4 deletions datafusion/optimizer/src/simplify_expressions/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,21 @@ pub static POWS_OF_TEN: [i128; 38] = [

/// returns true if `needle` is found in a chain of search_op
/// expressions. Such as: (A AND B) AND C
pub fn expr_contains(expr: &Expr, needle: &Expr, search_op: Operator) -> bool {
fn expr_contains_inner(expr: &Expr, needle: &Expr, search_op: Operator) -> bool {
match expr {
Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == search_op => {
expr_contains(left, needle, search_op)
|| expr_contains(right, needle, search_op)
expr_contains_inner(left, needle, search_op)
|| expr_contains_inner(right, needle, search_op)
}
_ => expr == needle,
}
}

/// check volatile calls and return if expr contains needle
pub fn expr_contains(expr: &Expr, needle: &Expr, search_op: Operator) -> bool {
expr_contains_inner(expr, needle, search_op) && !needle.is_volatile()
}

/// Deletes all 'needles' or remains one 'needle' that are found in a chain of xor
/// expressions. Such as: A ^ (A ^ (B ^ A))
pub fn delete_xor_in_complex_expr(expr: &Expr, needle: &Expr, is_left: bool) -> Expr {
Expand Down Expand Up @@ -206,7 +211,7 @@ pub fn is_false(expr: &Expr) -> bool {

/// returns true if `haystack` looks like (needle OP X) or (X OP needle)
pub fn is_op_with(target_op: Operator, haystack: &Expr, needle: &Expr) -> bool {
matches!(haystack, Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == &target_op && (needle == left.as_ref() || needle == right.as_ref()))
matches!(haystack, Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == &target_op && (needle == left.as_ref() || needle == right.as_ref()) && !needle.is_volatile())
}

/// returns true if `not_expr` is !`expr` (not)
Expand Down

0 comments on commit 6b76a35

Please sign in to comment.