diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs index e90f9c32ee87..b1b889136b35 100644 --- a/datafusion/physical-expr-common/src/physical_expr.rs +++ b/datafusion/physical-expr-common/src/physical_expr.rs @@ -31,6 +31,9 @@ use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_expr_common::interval_arithmetic::Interval; use datafusion_expr_common::sort_properties::ExprProperties; +/// Shared [`PhysicalExpr`]. +pub type PhysicalExprRef = Arc; + /// [`PhysicalExpr`]s represent expressions such as `A + 1` or `CAST(c1 AS int)`. /// /// `PhysicalExpr` knows its type, nullability and can be evaluated directly on diff --git a/datafusion/physical-expr/src/physical_expr.rs b/datafusion/physical-expr/src/physical_expr.rs index 9a9f40b6a1d4..a4184845a0de 100644 --- a/datafusion/physical-expr/src/physical_expr.rs +++ b/datafusion/physical-expr/src/physical_expr.rs @@ -17,12 +17,11 @@ use std::sync::Arc; +use datafusion_common::HashMap; pub(crate) use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +pub use datafusion_physical_expr_common::physical_expr::PhysicalExprRef; use itertools::izip; -/// Shared [`PhysicalExpr`]. -pub type PhysicalExprRef = Arc; - /// This function is similar to the `contains` method of `Vec`. It finds /// whether `expr` is among `physical_exprs`. pub fn physical_exprs_contains( @@ -48,31 +47,24 @@ pub fn physical_exprs_bag_equal( lhs: &[Arc], rhs: &[Arc], ) -> bool { - // TODO: Once we can use `HashMap`s with `Arc`, this - // function should use a `HashMap` to reduce computational complexity. - if lhs.len() == rhs.len() { - let mut rhs_vec = rhs.to_vec(); - for expr in lhs { - if let Some(idx) = rhs_vec.iter().position(|e| expr.eq(e)) { - rhs_vec.swap_remove(idx); - } else { - return false; - } - } - true - } else { - false + let mut multi_set_lhs: HashMap<_, usize> = HashMap::new(); + let mut multi_set_rhs: HashMap<_, usize> = HashMap::new(); + for expr in lhs { + *multi_set_lhs.entry(expr).or_insert(0) += 1; + } + for expr in rhs { + *multi_set_rhs.entry(expr).or_insert(0) += 1; } + multi_set_lhs == multi_set_rhs } #[cfg(test)] mod tests { - use std::sync::Arc; + use super::*; use crate::expressions::{Column, Literal}; use crate::physical_expr::{ physical_exprs_bag_equal, physical_exprs_contains, physical_exprs_equal, - PhysicalExpr, }; use datafusion_common::ScalarValue;