From bed57df3e8dc04961755da593d345c61d0e1be39 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Wed, 15 May 2024 21:52:25 +0300 Subject: [PATCH 01/15] [MINOR]: Move pipeline checker rule to the end (#10502) * Move pipeline checker to last * Update slt --- datafusion/core/src/physical_optimizer/optimizer.rs | 10 +++++----- datafusion/sqllogictest/test_files/explain.slt | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/optimizer.rs b/datafusion/core/src/physical_optimizer/optimizer.rs index 08cbf68fa617..416985983dfe 100644 --- a/datafusion/core/src/physical_optimizer/optimizer.rs +++ b/datafusion/core/src/physical_optimizer/optimizer.rs @@ -112,11 +112,6 @@ impl PhysicalOptimizer { // Remove the ancillary output requirement operator since we are done with the planning // phase. Arc::new(OutputRequirements::new_remove_mode()), - // The PipelineChecker rule will reject non-runnable query plans that use - // pipeline-breaking operators on infinite input(s). The rule generates a - // diagnostic error message when this happens. It makes no changes to the - // given query plan; i.e. it only acts as a final gatekeeping rule. - Arc::new(PipelineChecker::new()), // The aggregation limiter will try to find situations where the accumulator count // is not tied to the cardinality, i.e. when the output of the aggregation is passed // into an `order by max(x) limit y`. In this case it will copy the limit value down @@ -129,6 +124,11 @@ impl PhysicalOptimizer { // are not present, the load of executors such as join or union will be // reduced by narrowing their input tables. Arc::new(ProjectionPushdown::new()), + // The PipelineChecker rule will reject non-runnable query plans that use + // pipeline-breaking operators on infinite input(s). The rule generates a + // diagnostic error message when this happens. It makes no changes to the + // given query plan; i.e. it only acts as a final gatekeeping rule. + Arc::new(PipelineChecker::new()), ]; Self::with_rules(rules) diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index 3a4ac747ebd6..92c537f975ad 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -252,9 +252,9 @@ physical_plan after OptimizeAggregateOrder SAME TEXT AS ABOVE physical_plan after ProjectionPushdown SAME TEXT AS ABOVE physical_plan after coalesce_batches SAME TEXT AS ABOVE physical_plan after OutputRequirements CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true -physical_plan after PipelineChecker SAME TEXT AS ABOVE physical_plan after LimitAggregation SAME TEXT AS ABOVE physical_plan after ProjectionPushdown SAME TEXT AS ABOVE +physical_plan after PipelineChecker SAME TEXT AS ABOVE physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true physical_plan_with_stats CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true, statistics=[Rows=Absent, Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:)]] @@ -311,9 +311,9 @@ physical_plan after coalesce_batches SAME TEXT AS ABOVE physical_plan after OutputRequirements 01)GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] 02)--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] -physical_plan after PipelineChecker SAME TEXT AS ABOVE physical_plan after LimitAggregation SAME TEXT AS ABOVE physical_plan after ProjectionPushdown SAME TEXT AS ABOVE +physical_plan after PipelineChecker SAME TEXT AS ABOVE physical_plan 01)GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] 02)--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] @@ -348,9 +348,9 @@ physical_plan after coalesce_batches SAME TEXT AS ABOVE physical_plan after OutputRequirements 01)GlobalLimitExec: skip=0, fetch=10 02)--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10 -physical_plan after PipelineChecker SAME TEXT AS ABOVE physical_plan after LimitAggregation SAME TEXT AS ABOVE physical_plan after ProjectionPushdown SAME TEXT AS ABOVE +physical_plan after PipelineChecker SAME TEXT AS ABOVE physical_plan 01)GlobalLimitExec: skip=0, fetch=10 02)--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10 From ea9c32540870615764f5e8ee1531b1c70dd27eed Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 15 May 2024 14:52:40 -0400 Subject: [PATCH 02/15] Minor: Extract parent/child limit calculation into a function, improve docs (#10501) * Minor: Extract parent/child limit calculation into a function, improve docs * Update datafusion/optimizer/src/push_down_limit.rs Co-authored-by: Oleks V --------- Co-authored-by: Oleks V --- datafusion/optimizer/src/push_down_limit.rs | 116 +++++++++++++------- 1 file changed, 77 insertions(+), 39 deletions(-) diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index 1af246fc556d..9190881335af 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -17,6 +17,7 @@ //! [`PushDownLimit`] pushes `LIMIT` earlier in the query plan +use std::cmp::min; use std::sync::Arc; use crate::optimizer::ApplyOrder; @@ -56,47 +57,12 @@ impl OptimizerRule for PushDownLimit { if let LogicalPlan::Limit(child) = &*limit.input { // Merge the Parent Limit and the Child Limit. - - // Case 0: Parent and Child are disjoint. (child_fetch <= skip) - // Before merging: - // |........skip........|---fetch-->| Parent Limit - // |...child_skip...|---child_fetch-->| Child Limit - // After merging: - // |.........(child_skip + skip).........| - // Before merging: - // |...skip...|------------fetch------------>| Parent Limit - // |...child_skip...|-------------child_fetch------------>| Child Limit - // After merging: - // |....(child_skip + skip)....|---(child_fetch - skip)-->| - - // Case 1: Parent is beyond the range of Child. (skip < child_fetch <= skip + fetch) - // Before merging: - // |...skip...|------------fetch------------>| Parent Limit - // |...child_skip...|-------------child_fetch------------>| Child Limit - // After merging: - // |....(child_skip + skip)....|---(child_fetch - skip)-->| - - // Case 2: Parent is in the range of Child. (skip + fetch < child_fetch) - // Before merging: - // |...skip...|---fetch-->| Parent Limit - // |...child_skip...|-------------child_fetch------------>| Child Limit - // After merging: - // |....(child_skip + skip)....|---fetch-->| - let parent_skip = limit.skip; - let new_fetch = match (limit.fetch, child.fetch) { - (Some(fetch), Some(child_fetch)) => { - Some(min(fetch, child_fetch.saturating_sub(parent_skip))) - } - (Some(fetch), None) => Some(fetch), - (None, Some(child_fetch)) => { - Some(child_fetch.saturating_sub(parent_skip)) - } - (None, None) => None, - }; + let (skip, fetch) = + combine_limit(limit.skip, limit.fetch, child.skip, child.fetch); let plan = LogicalPlan::Limit(Limit { - skip: child.skip + parent_skip, - fetch: new_fetch, + skip, + fetch, input: Arc::new((*child.input).clone()), }); return self @@ -217,6 +183,78 @@ impl OptimizerRule for PushDownLimit { } } +/// Combines two limits into a single +/// +/// Returns the combined limit `(skip, fetch)` +/// +/// # Case 0: Parent and Child are disjoint. (`child_fetch <= skip`) +/// +/// ```text +/// Before merging: +/// |........skip........|---fetch-->| Parent Limit +/// |...child_skip...|---child_fetch-->| Child Limit +/// ``` +/// +/// After merging: +/// ```text +/// |.........(child_skip + skip).........| +/// ``` +/// +/// Before merging: +/// ```text +/// |...skip...|------------fetch------------>| Parent Limit +/// |...child_skip...|-------------child_fetch------------>| Child Limit +/// ``` +/// +/// After merging: +/// ```text +/// |....(child_skip + skip)....|---(child_fetch - skip)-->| +/// ``` +/// +/// # Case 1: Parent is beyond the range of Child. (`skip < child_fetch <= skip + fetch`) +/// +/// Before merging: +/// ```text +/// |...skip...|------------fetch------------>| Parent Limit +/// |...child_skip...|-------------child_fetch------------>| Child Limit +/// ``` +/// +/// After merging: +/// ```text +/// |....(child_skip + skip)....|---(child_fetch - skip)-->| +/// ``` +/// +/// # Case 2: Parent is in the range of Child. (`skip + fetch < child_fetch`) +/// Before merging: +/// ```text +/// |...skip...|---fetch-->| Parent Limit +/// |...child_skip...|-------------child_fetch------------>| Child Limit +/// ``` +/// +/// After merging: +/// ```text +/// |....(child_skip + skip)....|---fetch-->| +/// ``` +fn combine_limit( + parent_skip: usize, + parent_fetch: Option, + child_skip: usize, + child_fetch: Option, +) -> (usize, Option) { + let combined_skip = child_skip.saturating_add(parent_skip); + + let combined_fetch = match (parent_fetch, child_fetch) { + (Some(parent_fetch), Some(child_fetch)) => { + Some(min(parent_fetch, child_fetch.saturating_sub(parent_skip))) + } + (Some(parent_fetch), None) => Some(parent_fetch), + (None, Some(child_fetch)) => Some(child_fetch.saturating_sub(parent_skip)), + (None, None) => None, + }; + + (combined_skip, combined_fetch) +} + fn push_down_join(join: &Join, limit: usize) -> Option { use JoinType::*; From 8199e9e6601d91320e395b43ba3a005ae7ba4816 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=9E=97=E4=BC=9F?= Date: Thu, 16 May 2024 02:53:36 +0800 Subject: [PATCH 03/15] Fix window expr deserialization (#10506) * Fix window expr deserialization * Improve naming and doc * Update window test --- .../core/tests/fuzz_cases/window_fuzz.rs | 34 ++----------------- datafusion/physical-plan/src/windows/mod.rs | 26 ++++++++++++++ .../proto/src/physical_plan/from_proto.rs | 12 ++++--- .../tests/cases/roundtrip_physical_plan.rs | 3 +- 4 files changed, 38 insertions(+), 37 deletions(-) diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index 2514324a9541..fe0c408dc114 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -22,11 +22,10 @@ use arrow::compute::{concat_batches, SortOptions}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; -use arrow_schema::{Field, Schema}; use datafusion::physical_plan::memory::MemoryExec; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::windows::{ - create_window_expr, BoundedWindowAggExec, WindowAggExec, + create_window_expr, schema_add_window_field, BoundedWindowAggExec, WindowAggExec, }; use datafusion::physical_plan::InputOrderMode::{Linear, PartiallySorted, Sorted}; use datafusion::physical_plan::{collect, InputOrderMode}; @@ -40,7 +39,6 @@ use datafusion_expr::{ }; use datafusion_physical_expr::expressions::{cast, col, lit}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; -use itertools::Itertools; use test_utils::add_empty_batches; use hashbrown::HashMap; @@ -276,7 +274,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> { }; let extended_schema = - schema_add_window_fields(&args, &schema, &window_fn, fn_name)?; + schema_add_window_field(&args, &schema, &window_fn, fn_name)?; let window_expr = create_window_expr( &window_fn, @@ -683,7 +681,7 @@ async fn run_window_test( exec1 = Arc::new(SortExec::new(sort_keys, exec1)) as _; } - let extended_schema = schema_add_window_fields(&args, &schema, &window_fn, &fn_name)?; + let extended_schema = schema_add_window_field(&args, &schema, &window_fn, &fn_name)?; let usual_window_exec = Arc::new(WindowAggExec::try_new( vec![create_window_expr( @@ -754,32 +752,6 @@ async fn run_window_test( Ok(()) } -// The planner has fully updated schema before calling the `create_window_expr` -// Replicate the same for this test -fn schema_add_window_fields( - args: &[Arc], - schema: &Arc, - window_fn: &WindowFunctionDefinition, - fn_name: &str, -) -> Result> { - let data_types = args - .iter() - .map(|e| e.clone().as_ref().data_type(schema)) - .collect::>>()?; - let window_expr_return_type = window_fn.return_type(&data_types)?; - let mut window_fields = schema - .fields() - .iter() - .map(|f| f.as_ref().clone()) - .collect_vec(); - window_fields.extend_from_slice(&[Field::new( - fn_name, - window_expr_return_type, - true, - )]); - Ok(Arc::new(Schema::new(window_fields))) -} - /// Return randomly sized record batches with: /// three sorted int32 columns 'a', 'b', 'c' ranged from 0..DISTINCT as columns /// one random int32 column x diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index d1223f78808c..42c630741cc9 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -42,6 +42,7 @@ use datafusion_physical_expr::{ window::{BuiltInWindowFunctionExpr, SlidingAggregateWindowExpr}, AggregateExpr, EquivalenceProperties, LexOrdering, PhysicalSortRequirement, }; +use itertools::Itertools; mod bounded_window_agg_exec; mod window_agg_exec; @@ -52,6 +53,31 @@ pub use datafusion_physical_expr::window::{ }; pub use window_agg_exec::WindowAggExec; +/// Build field from window function and add it into schema +pub fn schema_add_window_field( + args: &[Arc], + schema: &Schema, + window_fn: &WindowFunctionDefinition, + fn_name: &str, +) -> Result> { + let data_types = args + .iter() + .map(|e| e.clone().as_ref().data_type(schema)) + .collect::>>()?; + let window_expr_return_type = window_fn.return_type(&data_types)?; + let mut window_fields = schema + .fields() + .iter() + .map(|f| f.as_ref().clone()) + .collect_vec(); + window_fields.extend_from_slice(&[Field::new( + fn_name, + window_expr_return_type, + false, + )]); + Ok(Arc::new(Schema::new(window_fields))) +} + /// Create a physical expression for window function #[allow(clippy::too_many_arguments)] pub fn create_window_expr( diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index c907e991fb86..a290f30586ce 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -40,7 +40,7 @@ use datafusion::physical_plan::expressions::{ in_list, BinaryExpr, CaseExpr, CastExpr, Column, IsNotNullExpr, IsNullExpr, LikeExpr, Literal, NegativeExpr, NotExpr, TryCastExpr, }; -use datafusion::physical_plan::windows::create_window_expr; +use datafusion::physical_plan::windows::{create_window_expr, schema_add_window_field}; use datafusion::physical_plan::{ ColumnStatistics, Partitioning, PhysicalExpr, Statistics, WindowExpr, }; @@ -155,14 +155,18 @@ pub fn parse_physical_window_expr( ) })?; + let fun: WindowFunctionDefinition = convert_required!(proto.window_function)?; + let name = proto.name.clone(); + let extended_schema = + schema_add_window_field(&window_node_expr, input_schema, &fun, &name)?; create_window_expr( - &convert_required!(proto.window_function)?, - proto.name.clone(), + &fun, + name, &window_node_expr, &partition_by, &order_by, Arc::new(window_frame), - input_schema, + &extended_schema, false, ) } diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 30a28081edff..dd8e450d3165 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -253,8 +253,7 @@ fn roundtrip_nested_loop_join() -> Result<()> { fn roundtrip_window() -> Result<()> { let field_a = Field::new("a", DataType::Int64, false); let field_b = Field::new("b", DataType::Int64, false); - let field_c = Field::new("FIRST_VALUE(a) PARTITION BY [b] ORDER BY [a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", DataType::Int64, false); - let schema = Arc::new(Schema::new(vec![field_a, field_b, field_c])); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); let window_frame = WindowFrame::new_bounds( datafusion_expr::WindowFrameUnits::Range, From a331b36a245c8c31f28b7b08af55cfd01c5d537a Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 15 May 2024 15:24:30 -0400 Subject: [PATCH 04/15] Update substrait requirement from 0.32.0 to 0.33.3 (#10516) Updates the requirements on [substrait](https://github.com/substrait-io/substrait-rs) to permit the latest version. - [Release notes](https://github.com/substrait-io/substrait-rs/releases) - [Changelog](https://github.com/substrait-io/substrait-rs/blob/main/CHANGELOG.md) - [Commits](https://github.com/substrait-io/substrait-rs/compare/v0.32.0...v0.33.3) --- updated-dependencies: - dependency-name: substrait dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- datafusion/substrait/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index dce8ce10b587..e4be6e68ff16 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -39,7 +39,7 @@ itertools = { workspace = true } object_store = { workspace = true } prost = "0.12" prost-types = "0.12" -substrait = "0.32.0" +substrait = "0.33.3" [dev-dependencies] tokio = { workspace = true } From c312ffe7d954563888a303beb8796848d20ff7c6 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 15 May 2024 15:58:15 -0400 Subject: [PATCH 05/15] Stop copying LogicalPlan and Exprs in `TypeCoercion` (10% faster planning) (#10356) * Add `LogicalPlan::recompute_schema` for handling rewrite passes * Stop copying LogicalPlan and Exprs in `TypeCoercion` * Apply suggestions from code review Co-authored-by: Oleks V --------- Co-authored-by: Oleks V --- .../optimizer/src/analyzer/type_coercion.rs | 125 ++++++++++++------ 1 file changed, 88 insertions(+), 37 deletions(-) diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 60b81aff9aaa..0f1f3ba7e729 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -22,7 +22,7 @@ use std::sync::Arc; use arrow::datatypes::{DataType, IntervalUnit}; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; use datafusion_common::{ exec_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, DFSchema, DataFusionError, Result, ScalarValue, @@ -31,8 +31,8 @@ use datafusion_expr::expr::{ self, AggregateFunctionDefinition, Between, BinaryExpr, Case, Exists, InList, InSubquery, Like, ScalarFunction, WindowFunction, }; -use datafusion_expr::expr_rewriter::rewrite_preserving_name; use datafusion_expr::expr_schema::cast_subquery; +use datafusion_expr::logical_plan::tree_node::unwrap_arc; use datafusion_expr::logical_plan::Subquery; use datafusion_expr::type_coercion::binary::{ comparison_coercion, get_input_types, like_coercion, @@ -52,6 +52,7 @@ use datafusion_expr::{ }; use crate::analyzer::AnalyzerRule; +use crate::utils::NamePreserver; #[derive(Default)] pub struct TypeCoercion {} @@ -68,26 +69,28 @@ impl AnalyzerRule for TypeCoercion { } fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { - analyze_internal(&DFSchema::empty(), &plan) + let empty_schema = DFSchema::empty(); + + let transformed_plan = plan + .transform_up_with_subqueries(|plan| analyze_internal(&empty_schema, plan))? + .data; + + Ok(transformed_plan) } } +/// use the external schema to handle the correlated subqueries case +/// +/// Assumes that children have already been optimized fn analyze_internal( - // use the external schema to handle the correlated subqueries case external_schema: &DFSchema, - plan: &LogicalPlan, -) -> Result { - // optimize child plans first - let new_inputs = plan - .inputs() - .iter() - .map(|p| analyze_internal(external_schema, p)) - .collect::>>()?; + plan: LogicalPlan, +) -> Result> { // get schema representing all available input fields. This is used for data type // resolution only, so order does not matter here - let mut schema = merge_schema(new_inputs.iter().collect()); + let mut schema = merge_schema(plan.inputs()); - if let LogicalPlan::TableScan(ts) = plan { + if let LogicalPlan::TableScan(ts) = &plan { let source_schema = DFSchema::try_from_qualified_schema( ts.table_name.clone(), &ts.source.schema(), @@ -100,25 +103,75 @@ fn analyze_internal( // select t2.c2 from t1 where t1.c1 in (select t2.c1 from t2 where t2.c2=t1.c3) schema.merge(external_schema); - let mut expr_rewrite = TypeCoercionRewriter { schema: &schema }; - - let new_expr = plan - .expressions() - .into_iter() - .map(|expr| { - // ensure aggregate names don't change: - // https://github.com/apache/datafusion/issues/3555 - rewrite_preserving_name(expr, &mut expr_rewrite) - }) - .collect::>>()?; - - plan.with_new_exprs(new_expr, new_inputs) + let mut expr_rewrite = TypeCoercionRewriter::new(&schema); + + let name_preserver = NamePreserver::new(&plan); + // apply coercion rewrite all expressions in the plan individually + plan.map_expressions(|expr| { + let original_name = name_preserver.save(&expr)?; + expr.rewrite(&mut expr_rewrite)? + .map_data(|expr| original_name.restore(expr)) + })? + // coerce join expressions specially + .map_data(|plan| expr_rewrite.coerce_joins(plan))? + // recompute the schema after the expressions have been rewritten as the types may have changed + .map_data(|plan| plan.recompute_schema()) } pub(crate) struct TypeCoercionRewriter<'a> { pub(crate) schema: &'a DFSchema, } +impl<'a> TypeCoercionRewriter<'a> { + fn new(schema: &'a DFSchema) -> Self { + Self { schema } + } + + /// Coerce join equality expressions + /// + /// Joins must be treated specially as their equality expressions are stored + /// as a parallel list of left and right expressions, rather than a single + /// equality expression + /// + /// For example, on_exprs like `t1.a = t2.b AND t1.x = t2.y` will be stored + /// as a list of `(t1.a, t2.b), (t1.x, t2.y)` + fn coerce_joins(&mut self, plan: LogicalPlan) -> Result { + let LogicalPlan::Join(mut join) = plan else { + return Ok(plan); + }; + + join.on = join + .on + .into_iter() + .map(|(lhs, rhs)| { + // coerce the arguments as though they were a single binary equality + // expression + let (lhs, rhs) = self.coerce_binary_op(lhs, Operator::Eq, rhs)?; + Ok((lhs, rhs)) + }) + .collect::>>()?; + + Ok(LogicalPlan::Join(join)) + } + + fn coerce_binary_op( + &self, + left: Expr, + op: Operator, + right: Expr, + ) -> Result<(Expr, Expr)> { + let (left_type, right_type) = get_input_types( + &left.get_type(self.schema)?, + &op, + &right.get_type(self.schema)?, + )?; + Ok(( + left.cast_to(&left_type, self.schema)?, + right.cast_to(&right_type, self.schema)?, + )) + } +} + impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { type Node = Expr; @@ -131,14 +184,15 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { subquery, outer_ref_columns, }) => { - let new_plan = analyze_internal(self.schema, &subquery)?; + let new_plan = analyze_internal(self.schema, unwrap_arc(subquery))?.data; Ok(Transformed::yes(Expr::ScalarSubquery(Subquery { subquery: Arc::new(new_plan), outer_ref_columns, }))) } Expr::Exists(Exists { subquery, negated }) => { - let new_plan = analyze_internal(self.schema, &subquery.subquery)?; + let new_plan = + analyze_internal(self.schema, unwrap_arc(subquery.subquery))?.data; Ok(Transformed::yes(Expr::Exists(Exists { subquery: Subquery { subquery: Arc::new(new_plan), @@ -152,7 +206,8 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { subquery, negated, }) => { - let new_plan = analyze_internal(self.schema, &subquery.subquery)?; + let new_plan = + analyze_internal(self.schema, unwrap_arc(subquery.subquery))?.data; let expr_type = expr.get_type(self.schema)?; let subquery_type = new_plan.schema().field(0).data_type(); let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(plan_datafusion_err!( @@ -221,15 +276,11 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { )))) } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let (left_type, right_type) = get_input_types( - &left.get_type(self.schema)?, - &op, - &right.get_type(self.schema)?, - )?; + let (left, right) = self.coerce_binary_op(*left, op, *right)?; Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new( - Box::new(left.cast_to(&left_type, self.schema)?), + Box::new(left), op, - Box::new(right.cast_to(&right_type, self.schema)?), + Box::new(right), )))) } Expr::Between(Between { From eddec8e78865c0f17bd089af641492b1d8e8a411 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Thu, 16 May 2024 07:43:17 +0800 Subject: [PATCH 06/15] Implement unparse `IS_NULL` to String and enhance the tests (#10529) * implement unparse is_null and add test * format the code --- datafusion/sql/src/unparser/expr.rs | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 804fa6d306b4..23e3d9ab3594 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -391,7 +391,9 @@ impl Unparser<'_> { Expr::ScalarVariable(_, _) => { not_impl_err!("Unsupported Expr conversion: {expr:?}") } - Expr::IsNull(_) => not_impl_err!("Unsupported Expr conversion: {expr:?}"), + Expr::IsNull(expr) => { + Ok(ast::Expr::IsNull(Box::new(self.expr_to_sql(expr)?))) + } Expr::IsNotFalse(_) => not_impl_err!("Unsupported Expr conversion: {expr:?}"), Expr::GetIndexedField(_) => { not_impl_err!("Unsupported Expr conversion: {expr:?}") @@ -863,7 +865,7 @@ mod tests { use datafusion_expr::{ case, col, exists, expr::{AggregateFunction, AggregateFunctionDefinition}, - lit, not, not_exists, table_scan, wildcard, ColumnarValue, ScalarUDF, + lit, not, not_exists, table_scan, when, wildcard, ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, Volatility, WindowFrame, WindowFunctionDefinition, }; @@ -933,6 +935,14 @@ mod tests { .otherwise(lit(ScalarValue::Null))?, r#"CASE "a" WHEN 1 THEN true WHEN 0 THEN false ELSE NULL END"#, ), + ( + when(col("a").is_null(), lit(true)).otherwise(lit(false))?, + r#"CASE WHEN "a" IS NULL THEN true ELSE false END"#, + ), + ( + when(col("a").is_not_null(), lit(true)).otherwise(lit(false))?, + r#"CASE WHEN "a" IS NOT NULL THEN true ELSE false END"#, + ), ( Expr::Cast(Cast { expr: Box::new(col("a")), @@ -959,6 +969,18 @@ mod tests { ScalarUDF::new_from_impl(DummyUDF::new()).call(vec![col("a"), col("b")]), r#"dummy_udf("a", "b")"#, ), + ( + ScalarUDF::new_from_impl(DummyUDF::new()) + .call(vec![col("a"), col("b")]) + .is_null(), + r#"dummy_udf("a", "b") IS NULL"#, + ), + ( + ScalarUDF::new_from_impl(DummyUDF::new()) + .call(vec![col("a"), col("b")]) + .is_not_null(), + r#"dummy_udf("a", "b") IS NOT NULL"#, + ), ( Expr::Like(Like { negated: true, @@ -1081,6 +1103,7 @@ mod tests { r#"COUNT(*) OVER (ORDER BY "a" DESC NULLS FIRST RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING)"#, ), (col("a").is_not_null(), r#""a" IS NOT NULL"#), + (col("a").is_null(), r#""a" IS NULL"#), ( (col("a") + col("b")).gt(lit(4)).is_true(), r#"(("a" + "b") > 4) IS TRUE"#, From 626c6bc8bf9b10aaf416b7494ae2c31c14cec5ce Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Thu, 16 May 2024 07:56:31 +0800 Subject: [PATCH 07/15] support merge batch for distinct array aggregate (#10526) Signed-off-by: jayzhan211 --- .../src/aggregate/array_agg_distinct.rs | 11 ++- .../sqllogictest/test_files/aggregate.slt | 67 +++++++++++++++++++ 2 files changed, 72 insertions(+), 6 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs index b8671c39a943..244a44acdcb5 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs @@ -153,12 +153,11 @@ impl Accumulator for DistinctArrayAggAccumulator { return Ok(()); } - let array = &states[0]; - - assert_eq!(array.len(), 1, "state array should only include 1 row!"); - // Unwrap outer ListArray then do update batch - let inner_array = array.as_list::().value(0); - self.update_batch(&[inner_array]) + states[0] + .as_list::() + .iter() + .flatten() + .try_for_each(|val| self.update_batch(&[val])) } fn evaluate(&mut self) -> Result { diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 40d66f9b52ce..78421d0b6431 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -198,6 +198,73 @@ statement error This feature is not implemented: LIMIT not supported in ARRAY_AG SELECT array_agg(c13 LIMIT 1) FROM aggregate_test_100 +# Test distinct aggregate function with merge batch +query II +with A as ( + select 1 as id, 2 as foo + UNION ALL + select 1, null + UNION ALL + select 1, null + UNION ALL + select 1, 3 + UNION ALL + select 1, 2 + ---- The order is non-deterministic, verify with length +) select array_length(array_agg(distinct a.foo)), sum(distinct 1) from A a group by a.id; +---- +3 1 + +# It has only AggregateExec with FinalPartitioned mode, so `merge_batch` is used +# If the plan is changed, whether the `merge_batch` is used should be verified to ensure the test coverage +query TT +explain with A as ( + select 1 as id, 2 as foo + UNION ALL + select 1, null + UNION ALL + select 1, null + UNION ALL + select 1, 3 + UNION ALL + select 1, 2 +) select array_length(array_agg(distinct a.foo)), sum(distinct 1) from A a group by a.id; +---- +logical_plan +01)Projection: array_length(ARRAY_AGG(DISTINCT a.foo)), SUM(DISTINCT Int64(1)) +02)--Aggregate: groupBy=[[a.id]], aggr=[[ARRAY_AGG(DISTINCT a.foo), SUM(DISTINCT Int64(1))]] +03)----SubqueryAlias: a +04)------SubqueryAlias: a +05)--------Union +06)----------Projection: Int64(1) AS id, Int64(2) AS foo +07)------------EmptyRelation +08)----------Projection: Int64(1) AS id, Int64(NULL) AS foo +09)------------EmptyRelation +10)----------Projection: Int64(1) AS id, Int64(NULL) AS foo +11)------------EmptyRelation +12)----------Projection: Int64(1) AS id, Int64(3) AS foo +13)------------EmptyRelation +14)----------Projection: Int64(1) AS id, Int64(2) AS foo +15)------------EmptyRelation +physical_plan +01)ProjectionExec: expr=[array_length(ARRAY_AGG(DISTINCT a.foo)@1) as array_length(ARRAY_AGG(DISTINCT a.foo)), SUM(DISTINCT Int64(1))@2 as SUM(DISTINCT Int64(1))] +02)--AggregateExec: mode=FinalPartitioned, gby=[id@0 as id], aggr=[ARRAY_AGG(DISTINCT a.foo), SUM(DISTINCT Int64(1))] +03)----CoalesceBatchesExec: target_batch_size=8192 +04)------RepartitionExec: partitioning=Hash([id@0], 4), input_partitions=5 +05)--------AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[ARRAY_AGG(DISTINCT a.foo), SUM(DISTINCT Int64(1))] +06)----------UnionExec +07)------------ProjectionExec: expr=[1 as id, 2 as foo] +08)--------------PlaceholderRowExec +09)------------ProjectionExec: expr=[1 as id, NULL as foo] +10)--------------PlaceholderRowExec +11)------------ProjectionExec: expr=[1 as id, NULL as foo] +12)--------------PlaceholderRowExec +13)------------ProjectionExec: expr=[1 as id, 3 as foo] +14)--------------PlaceholderRowExec +15)------------ProjectionExec: expr=[1 as id, 2 as foo] +16)--------------PlaceholderRowExec + + # FIX: custom absolute values # csv_query_avg_multi_batch From 5a8348f7111b2b0d39f2bd3fd1b1534338113b9f Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Thu, 16 May 2024 08:30:58 +0800 Subject: [PATCH 08/15] UDAF: Extend more args to `state_fields` and `groups_accumulator_supported` and introduce `ReversedUDAF` (#10525) * extends args Signed-off-by: jayzhan211 * reuse accumulator args Signed-off-by: jayzhan211 * fix example Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- datafusion-examples/examples/advanced_udaf.rs | 15 ++--- .../examples/simplify_udaf_expression.rs | 11 +--- .../user_defined/user_defined_aggregates.rs | 2 +- datafusion/expr/src/expr_fn.rs | 8 +-- datafusion/expr/src/function.rs | 51 ++++++++++++----- datafusion/expr/src/udaf.rs | 57 +++++++++++-------- .../functions-aggregate/src/covariance.rs | 22 +++---- .../functions-aggregate/src/first_last.rs | 15 ++--- .../simplify_expressions/expr_simplifier.rs | 6 +- .../physical-expr-common/src/aggregate/mod.rs | 43 ++++++++++---- 10 files changed, 128 insertions(+), 102 deletions(-) diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs index 342a23b6e73d..cf284472212f 100644 --- a/datafusion-examples/examples/advanced_udaf.rs +++ b/datafusion-examples/examples/advanced_udaf.rs @@ -31,8 +31,8 @@ use datafusion::error::Result; use datafusion::prelude::*; use datafusion_common::{cast::as_float64_array, ScalarValue}; use datafusion_expr::{ - function::AccumulatorArgs, Accumulator, AggregateUDF, AggregateUDFImpl, - GroupsAccumulator, Signature, + function::{AccumulatorArgs, StateFieldsArgs}, + Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature, }; /// This example shows how to use the full AggregateUDFImpl API to implement a user @@ -92,21 +92,16 @@ impl AggregateUDFImpl for GeoMeanUdaf { } /// This is the description of the state. accumulator's state() must match the types here. - fn state_fields( - &self, - _name: &str, - value_type: DataType, - _ordering_fields: Vec, - ) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![ - Field::new("prod", value_type, true), + Field::new("prod", args.return_type.clone(), true), Field::new("n", DataType::UInt32, true), ]) } /// Tell DataFusion that this aggregate supports the more performant `GroupsAccumulator` /// which is used for cases when there are grouping columns in the query - fn groups_accumulator_supported(&self) -> bool { + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { true } diff --git a/datafusion-examples/examples/simplify_udaf_expression.rs b/datafusion-examples/examples/simplify_udaf_expression.rs index 92deb20272e4..08b6bcab0190 100644 --- a/datafusion-examples/examples/simplify_udaf_expression.rs +++ b/datafusion-examples/examples/simplify_udaf_expression.rs @@ -17,7 +17,7 @@ use arrow_schema::{Field, Schema}; use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility}; -use datafusion_expr::function::AggregateFunctionSimplification; +use datafusion_expr::function::{AggregateFunctionSimplification, StateFieldsArgs}; use datafusion_expr::simplify::SimplifyInfo; use std::{any::Any, sync::Arc}; @@ -70,16 +70,11 @@ impl AggregateUDFImpl for BetterAvgUdaf { unimplemented!("should not be invoked") } - fn state_fields( - &self, - _name: &str, - _value_type: DataType, - _ordering_fields: Vec, - ) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { unimplemented!("should not be invoked") } - fn groups_accumulator_supported(&self) -> bool { + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { true } diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 8f02fb30b013..d199f04ba781 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -725,7 +725,7 @@ impl AggregateUDFImpl for TestGroupsAccumulator { panic!("accumulator shouldn't invoke"); } - fn groups_accumulator_supported(&self) -> bool { + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { true } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 1d976a12cc4f..64763a973687 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -23,6 +23,7 @@ use crate::expr::{ }; use crate::function::{ AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory, + StateFieldsArgs, }; use crate::{ aggregate_function, conditional_expressions::CaseBuilder, logical_plan::Subquery, @@ -690,12 +691,7 @@ impl AggregateUDFImpl for SimpleAggregateUDF { (self.accumulator)(acc_args) } - fn state_fields( - &self, - _name: &str, - _value_type: DataType, - _ordering_fields: Vec, - ) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { Ok(self.state_fields.clone()) } } diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 4e4d77924a9d..714cfa1af671 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -19,7 +19,7 @@ use crate::ColumnarValue; use crate::{Accumulator, Expr, PartitionEvaluator}; -use arrow::datatypes::{DataType, Schema}; +use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::Result; use std::sync::Arc; @@ -41,11 +41,14 @@ pub type ReturnTypeFunction = /// [`AccumulatorArgs`] contains information about how an aggregate /// function was called, including the types of its arguments and any optional /// ordering expressions. +#[derive(Debug)] pub struct AccumulatorArgs<'a> { /// The return type of the aggregate function. pub data_type: &'a DataType, + /// The schema of the input arguments pub schema: &'a Schema, + /// Whether to ignore nulls. /// /// SQL allows the user to specify `IGNORE NULLS`, for example: @@ -66,22 +69,40 @@ pub struct AccumulatorArgs<'a> { /// /// If no `ORDER BY` is specified, `sort_exprs`` will be empty. pub sort_exprs: &'a [Expr], + + /// Whether the aggregate function is distinct. + /// + /// ```sql + /// SELECT COUNT(DISTINCT column1) FROM t; + /// ``` + pub is_distinct: bool, + + /// The input type of the aggregate function. + pub input_type: &'a DataType, + + /// The number of arguments the aggregate function takes. + pub args_num: usize, } -impl<'a> AccumulatorArgs<'a> { - pub fn new( - data_type: &'a DataType, - schema: &'a Schema, - ignore_nulls: bool, - sort_exprs: &'a [Expr], - ) -> Self { - Self { - data_type, - schema, - ignore_nulls, - sort_exprs, - } - } +/// [`StateFieldsArgs`] contains information about the fields that an +/// aggregate function's accumulator should have. Used for [`AggregateUDFImpl::state_fields`]. +/// +/// [`AggregateUDFImpl::state_fields`]: crate::udaf::AggregateUDFImpl::state_fields +pub struct StateFieldsArgs<'a> { + /// The name of the aggregate function. + pub name: &'a str, + + /// The input type of the aggregate function. + pub input_type: &'a DataType, + + /// The return type of the aggregate function. + pub return_type: &'a DataType, + + /// The ordering fields of the aggregate function. + pub ordering_fields: &'a [Field], + + /// Whether the aggregate function is distinct. + pub is_distinct: bool, } /// Factory that returns an accumulator for the given aggregate function. diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 95121d78e7aa..4fd8d51679f0 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -17,7 +17,9 @@ //! [`AggregateUDF`]: User Defined Aggregate Functions -use crate::function::{AccumulatorArgs, AggregateFunctionSimplification}; +use crate::function::{ + AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs, +}; use crate::groups_accumulator::GroupsAccumulator; use crate::utils::format_state_name; use crate::{Accumulator, Expr}; @@ -177,18 +179,13 @@ impl AggregateUDF { /// for more details. /// /// This is used to support multi-phase aggregations - pub fn state_fields( - &self, - name: &str, - value_type: DataType, - ordering_fields: Vec, - ) -> Result> { - self.inner.state_fields(name, value_type, ordering_fields) + pub fn state_fields(&self, args: StateFieldsArgs) -> Result> { + self.inner.state_fields(args) } /// See [`AggregateUDFImpl::groups_accumulator_supported`] for more details. - pub fn groups_accumulator_supported(&self) -> bool { - self.inner.groups_accumulator_supported() + pub fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { + self.inner.groups_accumulator_supported(args) } /// See [`AggregateUDFImpl::create_groups_accumulator`] for more details. @@ -232,7 +229,7 @@ where /// # use arrow::datatypes::DataType; /// # use datafusion_common::{DataFusionError, plan_err, Result}; /// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility, Expr}; -/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, function::AccumulatorArgs}; +/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, function::{AccumulatorArgs, StateFieldsArgs}}; /// # use arrow::datatypes::Schema; /// # use arrow::datatypes::Field; /// #[derive(Debug, Clone)] @@ -261,9 +258,9 @@ where /// } /// // This is the accumulator factory; DataFusion uses it to create new accumulators. /// fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { unimplemented!() } -/// fn state_fields(&self, _name: &str, value_type: DataType, _ordering_fields: Vec) -> Result> { +/// fn state_fields(&self, args: StateFieldsArgs) -> Result> { /// Ok(vec![ -/// Field::new("value", value_type, true), +/// Field::new("value", args.return_type.clone(), true), /// Field::new("ordering", DataType::UInt32, true) /// ]) /// } @@ -319,19 +316,17 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// The name of the fields must be unique within the query and thus should /// be derived from `name`. See [`format_state_name`] for a utility function /// to generate a unique name. - fn state_fields( - &self, - name: &str, - value_type: DataType, - ordering_fields: Vec, - ) -> Result> { - let value_fields = vec![Field::new( - format_state_name(name, "value"), - value_type, + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + let fields = vec![Field::new( + format_state_name(args.name, "value"), + args.return_type.clone(), true, )]; - Ok(value_fields.into_iter().chain(ordering_fields).collect()) + Ok(fields + .into_iter() + .chain(args.ordering_fields.to_vec()) + .collect()) } /// If the aggregate expression has a specialized @@ -344,7 +339,7 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// `Self::accumulator` for certain queries, such as when this aggregate is /// used as a window function or when there no GROUP BY columns in the /// query. - fn groups_accumulator_supported(&self) -> bool { + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { false } @@ -389,6 +384,20 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { fn simplify(&self) -> Option { None } + + /// Returns the reverse expression of the aggregate function. + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::NotSupported + } +} + +pub enum ReversedUDAF { + /// The expression is the same as the original expression, like SUM, COUNT + Identical, + /// The expression does not support reverse calculation, like ArrayAgg + NotSupported, + /// The expression is different from the original expression + Reversed(Arc), } /// AggregateUDF that adds an alias to the underlying function. It is better to diff --git a/datafusion/functions-aggregate/src/covariance.rs b/datafusion/functions-aggregate/src/covariance.rs index 1210e1529dbb..6f03b256fd9f 100644 --- a/datafusion/functions-aggregate/src/covariance.rs +++ b/datafusion/functions-aggregate/src/covariance.rs @@ -30,8 +30,10 @@ use datafusion_common::{ ScalarValue, }; use datafusion_expr::{ - function::AccumulatorArgs, type_coercion::aggregates::NUMERICS, - utils::format_state_name, Accumulator, AggregateUDFImpl, Signature, Volatility, + function::{AccumulatorArgs, StateFieldsArgs}, + type_coercion::aggregates::NUMERICS, + utils::format_state_name, + Accumulator, AggregateUDFImpl, Signature, Volatility, }; use datafusion_physical_expr_common::aggregate::stats::StatsType; @@ -101,12 +103,8 @@ impl AggregateUDFImpl for CovarianceSample { Ok(DataType::Float64) } - fn state_fields( - &self, - name: &str, - _value_type: DataType, - _ordering_fields: Vec, - ) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + let name = args.name; Ok(vec![ Field::new(format_state_name(name, "count"), DataType::UInt64, true), Field::new(format_state_name(name, "mean1"), DataType::Float64, true), @@ -176,12 +174,8 @@ impl AggregateUDFImpl for CovariancePopulation { Ok(DataType::Float64) } - fn state_fields( - &self, - name: &str, - _value_type: DataType, - _ordering_fields: Vec, - ) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + let name = args.name; Ok(vec![ Field::new(format_state_name(name, "count"), DataType::UInt64, true), Field::new(format_state_name(name, "mean1"), DataType::Float64, true), diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index e3b685e90376..5d3d48344014 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -24,7 +24,7 @@ use datafusion_common::utils::{compare_rows, get_arrayref_at_indices, get_row_at use datafusion_common::{ arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, }; -use datafusion_expr::function::AccumulatorArgs; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::NUMERICS; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ @@ -147,18 +147,13 @@ impl AggregateUDFImpl for FirstValue { .map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _) } - fn state_fields( - &self, - name: &str, - value_type: DataType, - ordering_fields: Vec, - ) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { let mut fields = vec![Field::new( - format_state_name(name, "first_value"), - value_type, + format_state_name(args.name, "first_value"), + args.return_type.clone(), true, )]; - fields.extend(ordering_fields); + fields.extend(args.ordering_fields.to_vec()); fields.push(Field::new("is_set", DataType::Boolean, true)); Ok(fields) } diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 55052542a8bf..455d659fb25e 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -1759,7 +1759,9 @@ fn inlist_except(mut l1: InList, l2: InList) -> Result { mod tests { use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema}; use datafusion_expr::{ - function::AggregateFunctionSimplification, interval_arithmetic::Interval, *, + function::{AccumulatorArgs, AggregateFunctionSimplification}, + interval_arithmetic::Interval, + *, }; use std::{ collections::HashMap, @@ -3783,7 +3785,7 @@ mod tests { unimplemented!("not needed for tests") } - fn groups_accumulator_supported(&self) -> bool { + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { unimplemented!("not needed for testing") } diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index 05641b373b72..da24f335b2f8 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -20,6 +20,7 @@ pub mod utils; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::{not_impl_err, Result}; +use datafusion_expr::function::StateFieldsArgs; use datafusion_expr::type_coercion::aggregates::check_arg_count; use datafusion_expr::{ function::AccumulatorArgs, Accumulator, AggregateUDF, Expr, GroupsAccumulator, @@ -74,6 +75,7 @@ pub fn create_aggregate_expr( ignore_nulls, ordering_fields, is_distinct, + input_type: input_exprs_types[0].clone(), })) } @@ -166,6 +168,7 @@ pub struct AggregateFunctionExpr { ignore_nulls: bool, ordering_fields: Vec, is_distinct: bool, + input_type: DataType, } impl AggregateFunctionExpr { @@ -191,11 +194,15 @@ impl AggregateExpr for AggregateFunctionExpr { } fn state_fields(&self) -> Result> { - self.fun.state_fields( - self.name(), - self.data_type.clone(), - self.ordering_fields.clone(), - ) + let args = StateFieldsArgs { + name: &self.name, + input_type: &self.input_type, + return_type: &self.data_type, + ordering_fields: &self.ordering_fields, + is_distinct: self.is_distinct, + }; + + self.fun.state_fields(args) } fn field(&self) -> Result { @@ -203,12 +210,15 @@ impl AggregateExpr for AggregateFunctionExpr { } fn create_accumulator(&self) -> Result> { - let acc_args = AccumulatorArgs::new( - &self.data_type, - &self.schema, - self.ignore_nulls, - &self.sort_exprs, - ); + let acc_args = AccumulatorArgs { + data_type: &self.data_type, + schema: &self.schema, + ignore_nulls: self.ignore_nulls, + sort_exprs: &self.sort_exprs, + is_distinct: self.is_distinct, + input_type: &self.input_type, + args_num: self.args.len(), + }; self.fun.accumulator(acc_args) } @@ -273,7 +283,16 @@ impl AggregateExpr for AggregateFunctionExpr { } fn groups_accumulator_supported(&self) -> bool { - self.fun.groups_accumulator_supported() + let args = AccumulatorArgs { + data_type: &self.data_type, + schema: &self.schema, + ignore_nulls: self.ignore_nulls, + sort_exprs: &self.sort_exprs, + is_distinct: self.is_distinct, + input_type: &self.input_type, + args_num: self.args.len(), + }; + self.fun.groups_accumulator_supported(args) } fn create_groups_accumulator(&self) -> Result> { From 357987f8d061ce0cdc608d083decf035d929d899 Mon Sep 17 00:00:00 2001 From: Xin Li <33629085+xinlifoobar@users.noreply.github.com> Date: Thu, 16 May 2024 15:39:04 +0800 Subject: [PATCH 09/15] Move min_max unit tests to slt (#10539) * Move min_max unit tests to slt * Enrich comments --- .../physical-expr/src/aggregate/min_max.rs | 506 ------------------ .../physical-expr/src/expressions/mod.rs | 20 - .../sqllogictest/test_files/aggregate.slt | 411 ++++++++++++++ 3 files changed, 411 insertions(+), 526 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/min_max.rs b/datafusion/physical-expr/src/aggregate/min_max.rs index 95ae3207462e..50bd24c487bf 100644 --- a/datafusion/physical-expr/src/aggregate/min_max.rs +++ b/datafusion/physical-expr/src/aggregate/min_max.rs @@ -1103,509 +1103,3 @@ impl Accumulator for SlidingMinAccumulator { std::mem::size_of_val(self) - std::mem::size_of_val(&self.min) + self.min.size() } } - -#[cfg(test)] -mod tests { - use super::*; - use crate::expressions::col; - use crate::expressions::tests::{aggregate, aggregate_new}; - use crate::{generic_test_op, generic_test_op_new}; - use arrow::datatypes::*; - use arrow::record_batch::RecordBatch; - use datafusion_common::ScalarValue::Decimal128; - - #[test] - fn min_decimal() -> Result<()> { - // min - let left = ScalarValue::Decimal128(Some(123), 10, 2); - let right = ScalarValue::Decimal128(Some(124), 10, 2); - let result = min(&left, &right)?; - assert_eq!(result, left); - - // min batch - let array: ArrayRef = Arc::new( - (1..6) - .map(Some) - .collect::() - .with_precision_and_scale(10, 0)?, - ); - - let result = min_batch(&array)?; - assert_eq!(result, ScalarValue::Decimal128(Some(1), 10, 0)); - - // min batch without values - let array: ArrayRef = Arc::new( - std::iter::repeat::>(None) - .take(0) - .collect::() - .with_precision_and_scale(10, 0)?, - ); - let result = min_batch(&array)?; - assert_eq!(ScalarValue::Decimal128(None, 10, 0), result); - - // min batch with agg - let array: ArrayRef = Arc::new( - (1..6) - .map(Some) - .collect::() - .with_precision_and_scale(10, 0)?, - ); - generic_test_op!( - array, - DataType::Decimal128(10, 0), - Min, - ScalarValue::Decimal128(Some(1), 10, 0) - ) - } - - #[test] - fn min_decimal_all_nulls() -> Result<()> { - // min batch all nulls - let array: ArrayRef = Arc::new( - std::iter::repeat::>(None) - .take(6) - .collect::() - .with_precision_and_scale(10, 0)?, - ); - generic_test_op!( - array, - DataType::Decimal128(10, 0), - Min, - ScalarValue::Decimal128(None, 10, 0) - ) - } - - #[test] - fn min_decimal_with_nulls() -> Result<()> { - // min batch with nulls - let array: ArrayRef = Arc::new( - (1..6) - .map(|i| if i == 2 { None } else { Some(i) }) - .collect::() - .with_precision_and_scale(10, 0)?, - ); - - generic_test_op!( - array, - DataType::Decimal128(10, 0), - Min, - ScalarValue::Decimal128(Some(1), 10, 0) - ) - } - - #[test] - fn max_decimal() -> Result<()> { - // max - let left = ScalarValue::Decimal128(Some(123), 10, 2); - let right = ScalarValue::Decimal128(Some(124), 10, 2); - let result = max(&left, &right)?; - assert_eq!(result, right); - - let right = ScalarValue::Decimal128(Some(124), 10, 3); - let result = max(&left, &right); - let err_msg = format!( - "MIN/MAX is not expected to receive scalars of incompatible types {:?}", - (Decimal128(Some(123), 10, 2), Decimal128(Some(124), 10, 3)) - ); - let expect = DataFusionError::Internal(err_msg); - assert!(expect - .strip_backtrace() - .starts_with(&result.unwrap_err().strip_backtrace())); - - // max batch - let array: ArrayRef = Arc::new( - (1..6) - .map(Some) - .collect::() - .with_precision_and_scale(10, 5)?, - ); - let result = max_batch(&array)?; - assert_eq!(result, ScalarValue::Decimal128(Some(5), 10, 5)); - - // max batch without values - let array: ArrayRef = Arc::new( - std::iter::repeat::>(None) - .take(0) - .collect::() - .with_precision_and_scale(10, 0)?, - ); - let result = max_batch(&array)?; - assert_eq!(ScalarValue::Decimal128(None, 10, 0), result); - - // max batch with agg - let array: ArrayRef = Arc::new( - (1..6) - .map(Some) - .collect::() - .with_precision_and_scale(10, 0)?, - ); - generic_test_op!( - array, - DataType::Decimal128(10, 0), - Max, - ScalarValue::Decimal128(Some(5), 10, 0) - ) - } - - #[test] - fn max_decimal_with_nulls() -> Result<()> { - let array: ArrayRef = Arc::new( - (1..6) - .map(|i| if i == 2 { None } else { Some(i) }) - .collect::() - .with_precision_and_scale(10, 0)?, - ); - generic_test_op!( - array, - DataType::Decimal128(10, 0), - Max, - ScalarValue::Decimal128(Some(5), 10, 0) - ) - } - - #[test] - fn max_decimal_all_nulls() -> Result<()> { - let array: ArrayRef = Arc::new( - std::iter::repeat::>(None) - .take(6) - .collect::() - .with_precision_and_scale(10, 0)?, - ); - generic_test_op!( - array, - DataType::Decimal128(10, 0), - Min, - ScalarValue::Decimal128(None, 10, 0) - ) - } - - #[test] - fn max_i32() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); - generic_test_op!(a, DataType::Int32, Max, ScalarValue::from(5i32)) - } - - #[test] - fn min_i32() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); - generic_test_op!(a, DataType::Int32, Min, ScalarValue::from(1i32)) - } - - #[test] - fn max_utf8() -> Result<()> { - let a: ArrayRef = Arc::new(StringArray::from(vec!["d", "a", "c", "b"])); - generic_test_op!(a, DataType::Utf8, Max, ScalarValue::from("d")) - } - - #[test] - fn max_large_utf8() -> Result<()> { - let a: ArrayRef = Arc::new(LargeStringArray::from(vec!["d", "a", "c", "b"])); - generic_test_op!( - a, - DataType::LargeUtf8, - Max, - ScalarValue::LargeUtf8(Some("d".to_string())) - ) - } - - #[test] - fn min_utf8() -> Result<()> { - let a: ArrayRef = Arc::new(StringArray::from(vec!["d", "a", "c", "b"])); - generic_test_op!(a, DataType::Utf8, Min, ScalarValue::from("a")) - } - - #[test] - fn min_large_utf8() -> Result<()> { - let a: ArrayRef = Arc::new(LargeStringArray::from(vec!["d", "a", "c", "b"])); - generic_test_op!( - a, - DataType::LargeUtf8, - Min, - ScalarValue::LargeUtf8(Some("a".to_string())) - ) - } - - #[test] - fn max_i32_with_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(1), - None, - Some(3), - Some(4), - Some(5), - ])); - generic_test_op!(a, DataType::Int32, Max, ScalarValue::from(5i32)) - } - - #[test] - fn min_i32_with_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(1), - None, - Some(3), - Some(4), - Some(5), - ])); - generic_test_op!(a, DataType::Int32, Min, ScalarValue::from(1i32)) - } - - #[test] - fn max_i32_all_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - generic_test_op!(a, DataType::Int32, Max, ScalarValue::Int32(None)) - } - - #[test] - fn min_i32_all_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - generic_test_op!(a, DataType::Int32, Min, ScalarValue::Int32(None)) - } - - #[test] - fn max_u32() -> Result<()> { - let a: ArrayRef = - Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32])); - generic_test_op!(a, DataType::UInt32, Max, ScalarValue::from(5_u32)) - } - - #[test] - fn min_u32() -> Result<()> { - let a: ArrayRef = - Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32])); - generic_test_op!(a, DataType::UInt32, Min, ScalarValue::from(1u32)) - } - - #[test] - fn max_f32() -> Result<()> { - let a: ArrayRef = - Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32])); - generic_test_op!(a, DataType::Float32, Max, ScalarValue::from(5_f32)) - } - - #[test] - fn min_f32() -> Result<()> { - let a: ArrayRef = - Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32])); - generic_test_op!(a, DataType::Float32, Min, ScalarValue::from(1_f32)) - } - - #[test] - fn max_f64() -> Result<()> { - let a: ArrayRef = - Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); - generic_test_op!(a, DataType::Float64, Max, ScalarValue::from(5_f64)) - } - - #[test] - fn min_f64() -> Result<()> { - let a: ArrayRef = - Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); - generic_test_op!(a, DataType::Float64, Min, ScalarValue::from(1_f64)) - } - - #[test] - fn min_date32() -> Result<()> { - let a: ArrayRef = Arc::new(Date32Array::from(vec![1, 2, 3, 4, 5])); - generic_test_op!(a, DataType::Date32, Min, ScalarValue::Date32(Some(1))) - } - - #[test] - fn min_date64() -> Result<()> { - let a: ArrayRef = Arc::new(Date64Array::from(vec![1, 2, 3, 4, 5])); - generic_test_op!(a, DataType::Date64, Min, ScalarValue::Date64(Some(1))) - } - - #[test] - fn max_date32() -> Result<()> { - let a: ArrayRef = Arc::new(Date32Array::from(vec![1, 2, 3, 4, 5])); - generic_test_op!(a, DataType::Date32, Max, ScalarValue::Date32(Some(5))) - } - - #[test] - fn max_date64() -> Result<()> { - let a: ArrayRef = Arc::new(Date64Array::from(vec![1, 2, 3, 4, 5])); - generic_test_op!(a, DataType::Date64, Max, ScalarValue::Date64(Some(5))) - } - - #[test] - fn min_time32second() -> Result<()> { - let a: ArrayRef = Arc::new(Time32SecondArray::from(vec![1, 2, 3, 4, 5])); - generic_test_op!( - a, - DataType::Time32(TimeUnit::Second), - Min, - ScalarValue::Time32Second(Some(1)) - ) - } - - #[test] - fn max_time32second() -> Result<()> { - let a: ArrayRef = Arc::new(Time32SecondArray::from(vec![1, 2, 3, 4, 5])); - generic_test_op!( - a, - DataType::Time32(TimeUnit::Second), - Max, - ScalarValue::Time32Second(Some(5)) - ) - } - - #[test] - fn min_time32millisecond() -> Result<()> { - let a: ArrayRef = Arc::new(Time32MillisecondArray::from(vec![1, 2, 3, 4, 5])); - generic_test_op!( - a, - DataType::Time32(TimeUnit::Millisecond), - Min, - ScalarValue::Time32Millisecond(Some(1)) - ) - } - - #[test] - fn max_time32millisecond() -> Result<()> { - let a: ArrayRef = Arc::new(Time32MillisecondArray::from(vec![1, 2, 3, 4, 5])); - generic_test_op!( - a, - DataType::Time32(TimeUnit::Millisecond), - Max, - ScalarValue::Time32Millisecond(Some(5)) - ) - } - - #[test] - fn min_time64microsecond() -> Result<()> { - let a: ArrayRef = Arc::new(Time64MicrosecondArray::from(vec![1, 2, 3, 4, 5])); - generic_test_op!( - a, - DataType::Time64(TimeUnit::Microsecond), - Min, - ScalarValue::Time64Microsecond(Some(1)) - ) - } - - #[test] - fn max_time64microsecond() -> Result<()> { - let a: ArrayRef = Arc::new(Time64MicrosecondArray::from(vec![1, 2, 3, 4, 5])); - generic_test_op!( - a, - DataType::Time64(TimeUnit::Microsecond), - Max, - ScalarValue::Time64Microsecond(Some(5)) - ) - } - - #[test] - fn min_time64nanosecond() -> Result<()> { - let a: ArrayRef = Arc::new(Time64NanosecondArray::from(vec![1, 2, 3, 4, 5])); - generic_test_op!( - a, - DataType::Time64(TimeUnit::Nanosecond), - Min, - ScalarValue::Time64Nanosecond(Some(1)) - ) - } - - #[test] - fn max_time64nanosecond() -> Result<()> { - let a: ArrayRef = Arc::new(Time64NanosecondArray::from(vec![1, 2, 3, 4, 5])); - generic_test_op!( - a, - DataType::Time64(TimeUnit::Nanosecond), - Max, - ScalarValue::Time64Nanosecond(Some(5)) - ) - } - - #[test] - fn max_new_timestamp_micro() -> Result<()> { - let dt = DataType::Timestamp(TimeUnit::Microsecond, None); - let actual = TimestampMicrosecondArray::from(vec![1, 2, 3, 4, 5]) - .with_data_type(dt.clone()); - let expected: ArrayRef = - Arc::new(TimestampMicrosecondArray::from(vec![5]).with_data_type(dt.clone())); - generic_test_op_new!(Arc::new(actual), dt.clone(), Max, &expected) - } - - #[test] - fn max_new_timestamp_micro_with_tz() -> Result<()> { - let dt = DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())); - let actual = TimestampMicrosecondArray::from(vec![1, 2, 3, 4, 5]) - .with_data_type(dt.clone()); - let expected: ArrayRef = - Arc::new(TimestampMicrosecondArray::from(vec![5]).with_data_type(dt.clone())); - generic_test_op_new!(Arc::new(actual), dt.clone(), Max, &expected) - } - - #[test] - fn max_bool() -> Result<()> { - let a: ArrayRef = Arc::new(BooleanArray::from(vec![false, false])); - generic_test_op!(a, DataType::Boolean, Max, ScalarValue::from(false))?; - - let a: ArrayRef = Arc::new(BooleanArray::from(vec![true, true])); - generic_test_op!(a, DataType::Boolean, Max, ScalarValue::from(true))?; - - let a: ArrayRef = Arc::new(BooleanArray::from(vec![false, true, false])); - generic_test_op!(a, DataType::Boolean, Max, ScalarValue::from(true))?; - - let a: ArrayRef = Arc::new(BooleanArray::from(vec![true, false, true])); - generic_test_op!(a, DataType::Boolean, Max, ScalarValue::from(true))?; - - let a: ArrayRef = Arc::new(BooleanArray::from(Vec::::new())); - generic_test_op!( - a, - DataType::Boolean, - Max, - ScalarValue::from(None as Option) - )?; - - let a: ArrayRef = Arc::new(BooleanArray::from(vec![None as Option])); - generic_test_op!( - a, - DataType::Boolean, - Max, - ScalarValue::from(None as Option) - )?; - - let a: ArrayRef = - Arc::new(BooleanArray::from(vec![None, Some(true), Some(false)])); - generic_test_op!(a, DataType::Boolean, Max, ScalarValue::from(true))?; - - Ok(()) - } - - #[test] - fn min_bool() -> Result<()> { - let a: ArrayRef = Arc::new(BooleanArray::from(vec![false, false])); - generic_test_op!(a, DataType::Boolean, Min, ScalarValue::from(false))?; - - let a: ArrayRef = Arc::new(BooleanArray::from(vec![true, true])); - generic_test_op!(a, DataType::Boolean, Min, ScalarValue::from(true))?; - - let a: ArrayRef = Arc::new(BooleanArray::from(vec![false, true, false])); - generic_test_op!(a, DataType::Boolean, Min, ScalarValue::from(false))?; - - let a: ArrayRef = Arc::new(BooleanArray::from(vec![true, false, true])); - generic_test_op!(a, DataType::Boolean, Min, ScalarValue::from(false))?; - - let a: ArrayRef = Arc::new(BooleanArray::from(Vec::::new())); - generic_test_op!( - a, - DataType::Boolean, - Min, - ScalarValue::from(None as Option) - )?; - - let a: ArrayRef = Arc::new(BooleanArray::from(vec![None as Option])); - generic_test_op!( - a, - DataType::Boolean, - Min, - ScalarValue::from(None as Option) - )?; - - let a: ArrayRef = - Arc::new(BooleanArray::from(vec![None, Some(true), Some(false)])); - generic_test_op!(a, DataType::Boolean, Min, ScalarValue::from(false))?; - - Ok(()) - } -} diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index c16b609e2375..980297b8b433 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -100,9 +100,7 @@ pub(crate) mod tests { use crate::AggregateExpr; use arrow::record_batch::RecordBatch; - use arrow_array::ArrayRef; use datafusion_common::{Result, ScalarValue}; - use datafusion_expr::EmitTo; /// macro to perform an aggregation using [`datafusion_expr::Accumulator`] and verify the /// result. @@ -250,22 +248,4 @@ pub(crate) mod tests { accum.update_batch(&values)?; accum.evaluate() } - - pub fn aggregate_new( - batch: &RecordBatch, - agg: Arc, - ) -> Result { - let mut accum = agg.create_groups_accumulator()?; - let expr = agg.expressions(); - let values = expr - .iter() - .map(|e| { - e.evaluate(batch) - .and_then(|v| v.into_array(batch.num_rows())) - }) - .collect::>>()?; - let indices = vec![0; batch.num_rows()]; - accum.update_batch(&values, &indices, None, 1)?; - accum.evaluate(EmitTo::All) - } } diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 78421d0b6431..983f8a085ba9 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -2547,6 +2547,417 @@ Select bit_xor(DISTINCT c), arrow_typeof(bit_xor(DISTINCT c)) from t; statement ok drop table t; +################# +# Min_Max Begin # +################# +# min_decimal, max_decimal +statement ok +CREATE TABLE decimals (value DECIMAL(10, 2)); + +statement ok +INSERT INTO decimals VALUES (123.0001), (124.00); + +query RR +SELECT MIN(value), MAX(value) FROM decimals; +---- +123 124 + +statement ok +DROP TABLE decimals; + +statement ok +CREATE TABLE decimals_batch (value DECIMAL(10, 0)); + +statement ok +INSERT INTO decimals_batch VALUES (1), (2), (3), (4), (5); + +query RR +SELECT MIN(value), MAX(value) FROM decimals_batch; +---- +1 5 + +statement ok +DROP TABLE decimals_batch; + +statement ok +CREATE TABLE decimals_empty (value DECIMAL(10, 0)); + +query RR +SELECT MIN(value), MAX(value) FROM decimals_empty; +---- +NULL NULL + +statement ok +DROP TABLE decimals_empty; + +# min_decimal_all_nulls, max_decimal_all_nulls +statement ok +CREATE TABLE decimals_all_nulls (value DECIMAL(10, 0)); + +statement ok +INSERT INTO decimals_all_nulls VALUES (NULL), (NULL), (NULL), (NULL), (NULL), (NULL); + +query RR +SELECT MIN(value), MAX(value) FROM decimals_all_nulls; +---- +NULL NULL + +statement ok +DROP TABLE decimals_all_nulls; + +# min_decimal_with_nulls, max_decimal_with_nulls +statement ok +CREATE TABLE decimals_with_nulls (value DECIMAL(10, 0)); + +statement ok +INSERT INTO decimals_with_nulls VALUES (1), (NULL), (3), (4), (5); + +query RR +SELECT MIN(value), MAX(value) FROM decimals_with_nulls; +---- +1 5 + +statement ok +DROP TABLE decimals_with_nulls; + +statement ok +CREATE TABLE decimals_error (value DECIMAL(10, 2)); + +statement ok +INSERT INTO decimals_error VALUES (123.00), (arrow_cast(124.001, 'Decimal128(10, 3)')); + +query RR +SELECT MIN(value), MAX(value) FROM decimals_error; +---- +123 124 + +statement ok +DROP TABLE decimals_error; + +statement ok +CREATE TABLE decimals_agg (value DECIMAL(10, 0)); + +statement ok +INSERT INTO decimals_agg VALUES (1), (2), (3), (4), (5); + +query RR +SELECT MIN(value), MAX(value) FROM decimals_agg; +---- +1 5 + +statement ok +DROP TABLE decimals_agg; + +# min_i32, max_i32 +statement ok +CREATE TABLE integers (value INT); + +statement ok +INSERT INTO integers VALUES (1), (2), (3), (4), (5); + +query II +SELECT MIN(value), MAX(value) FROM integers +---- +1 5 + +statement ok +DROP TABLE integers; + +# min_utf8, max_utf8 +statement ok +CREATE TABLE strings (value TEXT); + +statement ok +INSERT INTO strings VALUES ('d'), ('a'), ('c'), ('b'); + +query TT +SELECT MIN(value), MAX(value) FROM strings +---- +a d + +statement ok +DROP TABLE strings; + +# min_i32_with_nulls, max_i32_with_nulls +statement ok +CREATE TABLE integers_with_nulls (value INT); + +statement ok +INSERT INTO integers_with_nulls VALUES (1), (NULL), (3), (4), (5); + +query II +SELECT MIN(value), MAX(value) FROM integers_with_nulls +---- +1 5 + +statement ok +DROP TABLE integers_with_nulls; + +# min_i32_all_nulls, max_i32_all_nulls +statement ok +CREATE TABLE integers_all_nulls (value INT); + +query II +SELECT MIN(value), MAX(value) FROM integers_all_nulls +---- +NULL NULL + +statement ok +DROP TABLE integers_all_nulls; + +# min_u32, max_u32 +statement ok +CREATE TABLE uintegers (value INT UNSIGNED); + +statement ok +INSERT INTO uintegers VALUES (1), (2), (3), (4), (5); + +query II +SELECT MIN(value), MAX(value) FROM uintegers +---- +1 5 + +statement ok +DROP TABLE uintegers; + +# min_f32, max_f32 +statement ok +CREATE TABLE floats (value FLOAT); + +statement ok +INSERT INTO floats VALUES (1.0), (2.0), (3.0), (4.0), (5.0); + +query RR +SELECT MIN(value), MAX(value) FROM floats +---- +1 5 + +statement ok +DROP TABLE floats; + +# min_f64, max_f64 +statement ok +CREATE TABLE doubles (value DOUBLE); + +statement ok +INSERT INTO doubles VALUES (1.0), (2.0), (3.0), (4.0), (5.0); + +query RR +SELECT MIN(value), MAX(value) FROM doubles +---- +1 5 + +statement ok +DROP TABLE doubles; + +# min_date, max_date +statement ok +CREATE TABLE dates (value DATE); + +statement ok +INSERT INTO dates VALUES ('1970-01-02'), ('1970-01-03'), ('1970-01-04'), ('1970-01-05'), ('1970-01-06'); + +query DD +SELECT MIN(value), MAX(value) FROM dates +---- +1970-01-02 1970-01-06 + +statement ok +DROP TABLE dates; + +# min_seconds, max_seconds +statement ok +CREATE TABLE times (value TIME); + +statement ok +INSERT INTO times VALUES ('00:00:01'), ('00:00:02'), ('00:00:03'), ('00:00:04'), ('00:00:05'); + +query DD +SELECT MIN(value), MAX(value) FROM times +---- +00:00:01 00:00:05 + +statement ok +DROP TABLE times; + +# min_milliseconds, max_milliseconds +statement ok +CREATE TABLE time32millisecond (value TIME); + +statement ok +INSERT INTO time32millisecond VALUES ('00:00:00.001'), ('00:00:00.002'), ('00:00:00.003'), ('00:00:00.004'), ('00:00:00.005'); + +query DD +SELECT MIN(value), MAX(value) FROM time32millisecond +---- +00:00:00.001 00:00:00.005 + +statement ok +DROP TABLE time32millisecond; + +# min_microseconds, max_microseconds +statement ok +CREATE TABLE time64microsecond (value TIME); + +statement ok +INSERT INTO time64microsecond VALUES ('00:00:00.000001'), ('00:00:00.000002'), ('00:00:00.000003'), ('00:00:00.000004'), ('00:00:00.000005'); + +query DD +SELECT MIN(value), MAX(value) FROM time64microsecond +---- +00:00:00.000001 00:00:00.000005 + +statement ok +DROP TABLE time64microsecond; + +# min_nanoseconds, max_nanoseconds +statement ok +CREATE TABLE time64nanosecond (value TIME); + +statement ok +INSERT INTO time64nanosecond VALUES ('00:00:00.000000001'), ('00:00:00.000000002'), ('00:00:00.000000003'), ('00:00:00.000000004'), ('00:00:00.000000005'); + +query DD +SELECT MIN(value), MAX(value) FROM time64nanosecond +---- +00:00:00.000000001 00:00:00.000000005 + +statement ok +DROP TABLE time64nanosecond; + +# min_timestamp, max_timestamp +statement ok +CREATE TABLE timestampmicrosecond (value TIMESTAMP); + +statement ok +INSERT INTO timestampmicrosecond VALUES ('1970-01-01 00:00:00.000001'), ('1970-01-01 00:00:00.000002'), ('1970-01-01 00:00:00.000003'), ('1970-01-01 00:00:00.000004'), ('1970-01-01 00:00:00.000005'); + +query PP +SELECT MIN(value), MAX(value) FROM timestampmicrosecond +---- +1970-01-01T00:00:00.000001 1970-01-01T00:00:00.000005 + +statement ok +DROP TABLE timestampmicrosecond; + +# max_bool +statement ok +CREATE TABLE max_bool (value BOOLEAN); + +statement ok +INSERT INTO max_bool VALUES (false), (false); + +query B +SELECT MAX(value) FROM max_bool +---- +false + +statement ok +DROP TABLE max_bool; + +statement ok +CREATE TABLE max_bool (value BOOLEAN); + +statement ok +INSERT INTO max_bool VALUES (true), (true); + +query B +SELECT MAX(value) FROM max_bool +---- +true + +statement ok +DROP TABLE max_bool; + +statement ok +CREATE TABLE max_bool (value BOOLEAN); + +statement ok +INSERT INTO max_bool VALUES (false), (true), (false); + +query B +SELECT MAX(value) FROM max_bool +---- +true + +statement ok +DROP TABLE max_bool; + +statement ok +CREATE TABLE max_bool (value BOOLEAN); + +statement ok +INSERT INTO max_bool VALUES (true), (false), (true); + +query B +SELECT MAX(value) FROM max_bool +---- +true + +statement ok +DROP TABLE max_bool; + +# min_bool +statement ok +CREATE TABLE min_bool (value BOOLEAN); + +statement ok +INSERT INTO min_bool VALUES (false), (false); + +query B +SELECT MIN(value) FROM min_bool +---- +false + +statement ok +DROP TABLE min_bool; + +statement ok +CREATE TABLE min_bool (value BOOLEAN); + +statement ok +INSERT INTO min_bool VALUES (true), (true); + +query B +SELECT MIN(value) FROM min_bool +---- +true + +statement ok +DROP TABLE min_bool; + +statement ok +CREATE TABLE min_bool (value BOOLEAN); + +statement ok +INSERT INTO min_bool VALUES (false), (true), (false); + +query B +SELECT MIN(value) FROM min_bool +---- +false + +statement ok +DROP TABLE min_bool; + +statement ok +CREATE TABLE min_bool (value BOOLEAN); + +statement ok +INSERT INTO min_bool VALUES (true), (false), (true); + +query B +SELECT MIN(value) FROM min_bool +---- +false + +statement ok +DROP TABLE min_bool; + +################# +# Min_Max End # +################# + statement ok create table bool_aggregate_functions ( c1 boolean not null, From ead66acc3dff75a1e55f5cc3c2a9f0264b7ae5dd Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Thu, 16 May 2024 18:15:04 +0800 Subject: [PATCH 10/15] Implement unparse `IsNotFalse` to String (#10538) * support to unparse IsNotFalse * reordering expressions in pattern matching --- datafusion/sql/src/unparser/expr.rs | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 23e3d9ab3594..cd45cf990863 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -356,6 +356,9 @@ impl Unparser<'_> { asc: _, nulls_first: _, }) => plan_err!("Sort expression should be handled by expr_to_unparsed"), + Expr::IsNull(expr) => { + Ok(ast::Expr::IsNull(Box::new(self.expr_to_sql(expr)?))) + } Expr::IsNotNull(expr) => { Ok(ast::Expr::IsNotNull(Box::new(self.expr_to_sql(expr)?))) } @@ -368,6 +371,9 @@ impl Unparser<'_> { Expr::IsFalse(expr) => { Ok(ast::Expr::IsFalse(Box::new(self.expr_to_sql(expr)?))) } + Expr::IsNotFalse(expr) => { + Ok(ast::Expr::IsNotFalse(Box::new(self.expr_to_sql(expr)?))) + } Expr::IsUnknown(expr) => { Ok(ast::Expr::IsUnknown(Box::new(self.expr_to_sql(expr)?))) } @@ -391,10 +397,6 @@ impl Unparser<'_> { Expr::ScalarVariable(_, _) => { not_impl_err!("Unsupported Expr conversion: {expr:?}") } - Expr::IsNull(expr) => { - Ok(ast::Expr::IsNull(Box::new(self.expr_to_sql(expr)?))) - } - Expr::IsNotFalse(_) => not_impl_err!("Unsupported Expr conversion: {expr:?}"), Expr::GetIndexedField(_) => { not_impl_err!("Unsupported Expr conversion: {expr:?}") } @@ -1116,6 +1118,10 @@ mod tests { (col("a") + col("b")).gt(lit(4)).is_false(), r#"(("a" + "b") > 4) IS FALSE"#, ), + ( + (col("a") + col("b")).gt(lit(4)).is_not_false(), + r#"(("a" + "b") > 4) IS NOT FALSE"#, + ), ( (col("a") + col("b")).gt(lit(4)).is_unknown(), r#"(("a" + "b") > 4) IS UNKNOWN"#, From 7535d93fa631f0fc42ca1521bc257ca20480b246 Mon Sep 17 00:00:00 2001 From: Xin Li <33629085+xinlifoobar@users.noreply.github.com> Date: Thu, 16 May 2024 18:16:20 +0800 Subject: [PATCH 11/15] Implement Unparse TryCast Expr --> String Support (#10542) * TryCast Expr --> String Support * Fix format --- datafusion/sql/src/unparser/expr.rs | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index cd45cf990863..a9bfed575701 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -26,7 +26,7 @@ use datafusion_common::{ }; use datafusion_expr::{ expr::{Alias, Exists, InList, ScalarFunction, Sort, WindowFunction}, - Between, BinaryExpr, Case, Cast, Expr, Like, Operator, + Between, BinaryExpr, Case, Cast, Expr, Like, Operator, TryCast, }; use sqlparser::ast::{ self, Expr as AstExpr, Function, FunctionArg, Ident, UnaryOperator, @@ -400,7 +400,14 @@ impl Unparser<'_> { Expr::GetIndexedField(_) => { not_impl_err!("Unsupported Expr conversion: {expr:?}") } - Expr::TryCast(_) => not_impl_err!("Unsupported Expr conversion: {expr:?}"), + Expr::TryCast(TryCast { expr, data_type }) => { + let inner_expr = self.expr_to_sql(expr)?; + Ok(ast::Expr::TryCast { + expr: Box::new(inner_expr), + data_type: self.arrow_dtype_to_ast_dtype(data_type)?, + format: None, + }) + } Expr::Wildcard { qualifier: _ } => { not_impl_err!("Unsupported Expr conversion: {expr:?}") } @@ -867,8 +874,9 @@ mod tests { use datafusion_expr::{ case, col, exists, expr::{AggregateFunction, AggregateFunctionDefinition}, - lit, not, not_exists, table_scan, when, wildcard, ColumnarValue, ScalarUDF, - ScalarUDFImpl, Signature, Volatility, WindowFrame, WindowFunctionDefinition, + lit, not, not_exists, table_scan, try_cast, when, wildcard, ColumnarValue, + ScalarUDF, ScalarUDFImpl, Signature, Volatility, WindowFrame, + WindowFunctionDefinition, }; use crate::unparser::dialect::CustomDialect; @@ -1144,6 +1152,14 @@ mod tests { not_exists(Arc::new(dummy_logical_plan.clone())), r#"NOT EXISTS (SELECT "t"."a" FROM "t" WHERE ("t"."a" = 1))"#, ), + ( + try_cast(col("a"), DataType::Date64), + r#"TRY_CAST("a" AS DATETIME)"#, + ), + ( + try_cast(col("a"), DataType::UInt32), + r#"TRY_CAST("a" AS INTEGER UNSIGNED)"#, + ), ]; for (expr, expected) in tests { From 410b04fe29f72ecea128a252ac4f51725f8ae7f6 Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Thu, 16 May 2024 20:08:47 +0800 Subject: [PATCH 12/15] Implement unparse `Placeholder` to String (#10540) --- datafusion/sql/src/unparser/expr.rs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index a9bfed575701..c871d1f21ffa 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -414,8 +414,8 @@ impl Unparser<'_> { Expr::GroupingSet(_) => { not_impl_err!("Unsupported Expr conversion: {expr:?}") } - Expr::Placeholder(_) => { - not_impl_err!("Unsupported Expr conversion: {expr:?}") + Expr::Placeholder(p) => { + Ok(ast::Expr::Value(ast::Value::Placeholder(p.id.to_string()))) } Expr::OuterReferenceColumn(_, _) => { not_impl_err!("Unsupported Expr conversion: {expr:?}") @@ -874,8 +874,8 @@ mod tests { use datafusion_expr::{ case, col, exists, expr::{AggregateFunction, AggregateFunctionDefinition}, - lit, not, not_exists, table_scan, try_cast, when, wildcard, ColumnarValue, - ScalarUDF, ScalarUDFImpl, Signature, Volatility, WindowFrame, + lit, not, not_exists, placeholder, table_scan, try_cast, when, wildcard, + ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, Volatility, WindowFrame, WindowFunctionDefinition, }; @@ -1160,6 +1160,7 @@ mod tests { try_cast(col("a"), DataType::UInt32), r#"TRY_CAST("a" AS INTEGER UNSIGNED)"#, ), + (col("x").eq(placeholder("$1")), r#"("x" = $1)"#), ]; for (expr, expected) in tests { From 842f3933e3496a022984c2a37254475a3bcde1bf Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Thu, 16 May 2024 21:23:33 +0800 Subject: [PATCH 13/15] Convert OuterReferenceColumn to a Column sql node (#10544) --- datafusion/sql/src/unparser/expr.rs | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index c871d1f21ffa..416ab03d1fa9 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -417,9 +417,7 @@ impl Unparser<'_> { Expr::Placeholder(p) => { Ok(ast::Expr::Value(ast::Value::Placeholder(p.id.to_string()))) } - Expr::OuterReferenceColumn(_, _) => { - not_impl_err!("Unsupported Expr conversion: {expr:?}") - } + Expr::OuterReferenceColumn(_, col) => self.col_to_sql(col), Expr::Unnest(_) => not_impl_err!("Unsupported Expr conversion: {expr:?}"), } } @@ -874,9 +872,9 @@ mod tests { use datafusion_expr::{ case, col, exists, expr::{AggregateFunction, AggregateFunctionDefinition}, - lit, not, not_exists, placeholder, table_scan, try_cast, when, wildcard, - ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, Volatility, WindowFrame, - WindowFunctionDefinition, + lit, not, not_exists, out_ref_col, placeholder, table_scan, try_cast, when, + wildcard, ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, Volatility, + WindowFrame, WindowFunctionDefinition, }; use crate::unparser::dialect::CustomDialect; @@ -1161,6 +1159,10 @@ mod tests { r#"TRY_CAST("a" AS INTEGER UNSIGNED)"#, ), (col("x").eq(placeholder("$1")), r#"("x" = $1)"#), + ( + out_ref_col(DataType::Int32, "t.a").gt(lit(1)), + r#"("t"."a" > 1)"#, + ), ]; for (expr, expected) in tests { From 87169f06ab590f20bd03b1be504a2119ddca6d68 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 16 May 2024 20:08:39 -0400 Subject: [PATCH 14/15] Stop copying LogicalPlan and Exprs in `PushDownFilter` (#10444) --- datafusion/expr/src/logical_plan/plan.rs | 10 + datafusion/optimizer/src/push_down_filter.rs | 623 ++++++++++--------- 2 files changed, 343 insertions(+), 290 deletions(-) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index ddf075c2c27b..4872e5acda5e 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -2407,6 +2407,16 @@ pub enum Distinct { On(DistinctOn), } +impl Distinct { + /// return a reference to the nodes input + pub fn input(&self) -> &Arc { + match self { + Distinct::All(input) => input, + Distinct::On(DistinctOn { input, .. }) => input, + } + } +} + /// Removes duplicate rows from the input #[derive(Clone, PartialEq, Eq, Hash)] pub struct DistinctOn { diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 57b38bd0d0fd..b684b5490342 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -14,6 +14,7 @@ //! [`PushDownFilter`] applies filters as early as possible +use indexmap::IndexSet; use std::collections::{HashMap, HashSet}; use std::sync::Arc; @@ -23,10 +24,9 @@ use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, }; use datafusion_common::{ - internal_err, plan_datafusion_err, qualified_name, Column, DFSchema, DFSchemaRef, + internal_err, plan_err, qualified_name, Column, DFSchema, DFSchemaRef, JoinConstraint, Result, }; -use datafusion_expr::expr::Alias; use datafusion_expr::expr_rewriter::replace_col; use datafusion_expr::logical_plan::tree_node::unwrap_arc; use datafusion_expr::logical_plan::{ @@ -131,7 +131,8 @@ use crate::{OptimizerConfig, OptimizerRule}; #[derive(Default)] pub struct PushDownFilter {} -/// For a given JOIN logical plan, determine whether each side of the join is preserved. +/// For a given JOIN type, determine whether each side of the join is preserved. +/// /// We say a join side is preserved if the join returns all or a subset of the rows from /// the relevant side, such that each row of the output table directly maps to a row of /// the preserved input table. If a table is not preserved, it can provide extra null rows. @@ -150,44 +151,33 @@ pub struct PushDownFilter {} /// non-preserved side it can be more tricky. /// /// Returns a tuple of booleans - (left_preserved, right_preserved). -fn lr_is_preserved(plan: &LogicalPlan) -> Result<(bool, bool)> { - match plan { - LogicalPlan::Join(Join { join_type, .. }) => match join_type { - JoinType::Inner => Ok((true, true)), - JoinType::Left => Ok((true, false)), - JoinType::Right => Ok((false, true)), - JoinType::Full => Ok((false, false)), - // No columns from the right side of the join can be referenced in output - // predicates for semi/anti joins, so whether we specify t/f doesn't matter. - JoinType::LeftSemi | JoinType::LeftAnti => Ok((true, false)), - // No columns from the left side of the join can be referenced in output - // predicates for semi/anti joins, so whether we specify t/f doesn't matter. - JoinType::RightSemi | JoinType::RightAnti => Ok((false, true)), - }, - LogicalPlan::CrossJoin(_) => Ok((true, true)), - _ => internal_err!("lr_is_preserved only valid for JOIN nodes"), +fn lr_is_preserved(join_type: JoinType) -> Result<(bool, bool)> { + match join_type { + JoinType::Inner => Ok((true, true)), + JoinType::Left => Ok((true, false)), + JoinType::Right => Ok((false, true)), + JoinType::Full => Ok((false, false)), + // No columns from the right side of the join can be referenced in output + // predicates for semi/anti joins, so whether we specify t/f doesn't matter. + JoinType::LeftSemi | JoinType::LeftAnti => Ok((true, false)), + // No columns from the left side of the join can be referenced in output + // predicates for semi/anti joins, so whether we specify t/f doesn't matter. + JoinType::RightSemi | JoinType::RightAnti => Ok((false, true)), } } /// For a given JOIN logical plan, determine whether each side of the join is preserved /// in terms on join filtering. -/// /// Predicates from join filter can only be pushed to preserved join side. -fn on_lr_is_preserved(plan: &LogicalPlan) -> Result<(bool, bool)> { - match plan { - LogicalPlan::Join(Join { join_type, .. }) => match join_type { - JoinType::Inner => Ok((true, true)), - JoinType::Left => Ok((false, true)), - JoinType::Right => Ok((true, false)), - JoinType::Full => Ok((false, false)), - JoinType::LeftSemi | JoinType::RightSemi => Ok((true, true)), - JoinType::LeftAnti => Ok((false, true)), - JoinType::RightAnti => Ok((true, false)), - }, - LogicalPlan::CrossJoin(_) => { - internal_err!("on_lr_is_preserved cannot be applied to CROSSJOIN nodes") - } - _ => internal_err!("on_lr_is_preserved only valid for JOIN nodes"), +fn on_lr_is_preserved(join_type: JoinType) -> Result<(bool, bool)> { + match join_type { + JoinType::Inner => Ok((true, true)), + JoinType::Left => Ok((false, true)), + JoinType::Right => Ok((true, false)), + JoinType::Full => Ok((false, false)), + JoinType::LeftSemi | JoinType::RightSemi => Ok((true, true)), + JoinType::LeftAnti => Ok((false, true)), + JoinType::RightAnti => Ok((true, false)), } } @@ -400,23 +390,20 @@ fn extract_or_clause(expr: &Expr, schema_columns: &HashSet) -> Option, - infer_predicates: Vec, - join_plan: &LogicalPlan, - left: &LogicalPlan, - right: &LogicalPlan, + inferred_join_predicates: Vec, + mut join: Join, on_filter: Vec, - is_inner_join: bool, ) -> Result> { - let on_filter_empty = on_filter.is_empty(); + let is_inner_join = join.join_type == JoinType::Inner; // Get pushable predicates from current optimizer state - let (left_preserved, right_preserved) = lr_is_preserved(join_plan)?; + let (left_preserved, right_preserved) = lr_is_preserved(join.join_type)?; // The predicates can be divided to three categories: // 1) can push through join to its children(left or right) // 2) can be converted to join conditions if the join type is Inner // 3) should be kept as filter conditions - let left_schema = left.schema(); - let right_schema = right.schema(); + let left_schema = join.left.schema(); + let right_schema = join.right.schema(); let mut left_push = vec![]; let mut right_push = vec![]; let mut keep_predicates = vec![]; @@ -438,7 +425,7 @@ fn push_down_all_join( } // For infer predicates, if they can not push through join, just drop them - for predicate in infer_predicates { + for predicate in inferred_join_predicates { if left_preserved && can_pushdown_join_predicate(&predicate, left_schema)? { left_push.push(predicate); } else if right_preserved @@ -449,7 +436,7 @@ fn push_down_all_join( } if !on_filter.is_empty() { - let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join_plan)?; + let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join.join_type)?; for on in on_filter { if on_left_preserved && can_pushdown_join_predicate(&on, left_schema)? { left_push.push(on) @@ -474,46 +461,29 @@ fn push_down_all_join( right_push.extend(extract_or_clauses_for_join(&join_conditions, right_schema)); } - let left = match conjunction(left_push) { - Some(predicate) => { - LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(left.clone()))?) - } - None => left.clone(), - }; - let right = match conjunction(right_push) { - Some(predicate) => { - LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(right.clone()))?) - } - None => right.clone(), - }; - // Create a new Join with the new `left` and `right` - // - // expressions() output for Join is a vector consisting of - // 1. join keys - columns mentioned in ON clause - // 2. optional predicate - in case join filter is not empty, - // it always will be the last element, otherwise result - // vector will contain only join keys (without additional - // element representing filter). - let mut exprs = join_plan.expressions(); - if !on_filter_empty { - exprs.pop(); - } - exprs.extend(join_conditions.into_iter().reduce(Expr::and)); - let plan = join_plan.with_new_exprs(exprs, vec![left, right])?; - - // wrap the join on the filter whose predicates must be kept - match conjunction(keep_predicates) { - Some(predicate) => { - let new_filter_plan = Filter::try_new(predicate, Arc::new(plan))?; - Ok(Transformed::yes(LogicalPlan::Filter(new_filter_plan))) - } - None => Ok(Transformed::no(plan)), + if let Some(predicate) = conjunction(left_push) { + join.left = Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.left)?)); } + if let Some(predicate) = conjunction(right_push) { + join.right = + Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.right)?)); + } + + // Add any new join conditions as the non join predicates + join.filter = conjunction(join_conditions); + + // wrap the join on the filter whose predicates must be kept, if any + let plan = LogicalPlan::Join(join); + let plan = if let Some(predicate) = conjunction(keep_predicates) { + LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(plan))?) + } else { + plan + }; + Ok(Transformed::yes(plan)) } fn push_down_join( - plan: &LogicalPlan, - join: &Join, + join: Join, parent_predicate: Option<&Expr>, ) -> Result> { // Split the parent predicate into individual conjunctive parts. @@ -526,93 +496,102 @@ fn push_down_join( .as_ref() .map_or_else(Vec::new, |filter| split_conjunction_owned(filter.clone())); - let mut is_inner_join = false; - let infer_predicates = if join.join_type == JoinType::Inner { - is_inner_join = true; - - // Only allow both side key is column. - let join_col_keys = join - .on - .iter() - .filter_map(|(l, r)| { - let left_col = l.try_as_col().cloned()?; - let right_col = r.try_as_col().cloned()?; - Some((left_col, right_col)) - }) - .collect::>(); - - // TODO refine the logic, introduce EquivalenceProperties to logical plan and infer additional filters to push down - // For inner joins, duplicate filters for joined columns so filters can be pushed down - // to both sides. Take the following query as an example: - // - // ```sql - // SELECT * FROM t1 JOIN t2 on t1.id = t2.uid WHERE t1.id > 1 - // ``` - // - // `t1.id > 1` predicate needs to be pushed down to t1 table scan, while - // `t2.uid > 1` predicate needs to be pushed down to t2 table scan. - // - // Join clauses with `Using` constraints also take advantage of this logic to make sure - // predicates reference the shared join columns are pushed to both sides. - // This logic should also been applied to conditions in JOIN ON clause - predicates - .iter() - .chain(on_filters.iter()) - .filter_map(|predicate| { - let mut join_cols_to_replace = HashMap::new(); - - let columns = match predicate.to_columns() { - Ok(columns) => columns, - Err(e) => return Some(Err(e)), - }; + // Are there any new join predicates that can be inferred from the filter expressions? + let inferred_join_predicates = + infer_join_predicates(&join, &predicates, &on_filters)?; - for col in columns.iter() { - for (l, r) in join_col_keys.iter() { - if col == l { - join_cols_to_replace.insert(col, r); - break; - } else if col == r { - join_cols_to_replace.insert(col, l); - break; - } - } - } + if on_filters.is_empty() + && predicates.is_empty() + && inferred_join_predicates.is_empty() + { + return Ok(Transformed::no(LogicalPlan::Join(join))); + } - if join_cols_to_replace.is_empty() { - return None; - } + push_down_all_join(predicates, inferred_join_predicates, join, on_filters) +} - let join_side_predicate = - match replace_col(predicate.clone(), &join_cols_to_replace) { - Ok(p) => p, - Err(e) => { - return Some(Err(e)); - } - }; +/// Extracts any equi-join join predicates from the given filter expressions. +/// +/// Parameters +/// * `join` the join in question +/// +/// * `predicates` the pushed down filter expression +/// +/// * `on_filters` filters from the join ON clause that have not already been +/// identified as join predicates +/// +fn infer_join_predicates( + join: &Join, + predicates: &[Expr], + on_filters: &[Expr], +) -> Result> { + if join.join_type != JoinType::Inner { + return Ok(vec![]); + } - Some(Ok(join_side_predicate)) - }) - .collect::>>()? - } else { - vec![] - }; + // Only allow both side key is column. + let join_col_keys = join + .on + .iter() + .filter_map(|(l, r)| { + let left_col = l.try_as_col()?; + let right_col = r.try_as_col()?; + Some((left_col, right_col)) + }) + .collect::>(); - if on_filters.is_empty() && predicates.is_empty() && infer_predicates.is_empty() { - return Ok(Transformed::no(plan.clone())); - } + // TODO refine the logic, introduce EquivalenceProperties to logical plan and infer additional filters to push down + // For inner joins, duplicate filters for joined columns so filters can be pushed down + // to both sides. Take the following query as an example: + // + // ```sql + // SELECT * FROM t1 JOIN t2 on t1.id = t2.uid WHERE t1.id > 1 + // ``` + // + // `t1.id > 1` predicate needs to be pushed down to t1 table scan, while + // `t2.uid > 1` predicate needs to be pushed down to t2 table scan. + // + // Join clauses with `Using` constraints also take advantage of this logic to make sure + // predicates reference the shared join columns are pushed to both sides. + // This logic should also been applied to conditions in JOIN ON clause + predicates + .iter() + .chain(on_filters.iter()) + .filter_map(|predicate| { + let mut join_cols_to_replace = HashMap::new(); + + let columns = match predicate.to_columns() { + Ok(columns) => columns, + Err(e) => return Some(Err(e)), + }; + + for col in columns.iter() { + for (l, r) in join_col_keys.iter() { + if col == *l { + join_cols_to_replace.insert(col, *r); + break; + } else if col == *r { + join_cols_to_replace.insert(col, *l); + break; + } + } + } - match push_down_all_join( - predicates, - infer_predicates, - plan, - &join.left, - &join.right, - on_filters, - is_inner_join, - ) { - Ok(plan) => Ok(Transformed::yes(plan.data)), - Err(e) => Err(e), - } + if join_cols_to_replace.is_empty() { + return None; + } + + let join_side_predicate = + match replace_col(predicate.clone(), &join_cols_to_replace) { + Ok(p) => p, + Err(e) => { + return Some(Err(e)); + } + }; + + Some(Ok(join_side_predicate)) + }) + .collect::>>() } impl OptimizerRule for PushDownFilter { @@ -641,46 +620,57 @@ impl OptimizerRule for PushDownFilter { plan: LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result> { - let filter = match plan { - LogicalPlan::Filter(ref filter) => filter, - LogicalPlan::Join(ref join) => return push_down_join(&plan, join, None), - _ => return Ok(Transformed::no(plan)), + if let LogicalPlan::Join(join) = plan { + return push_down_join(join, None); + }; + + let plan_schema = plan.schema().clone(); + + let LogicalPlan::Filter(mut filter) = plan else { + return Ok(Transformed::no(plan)); }; - let child_plan = filter.input.as_ref(); - let new_plan = match child_plan { - LogicalPlan::Filter(ref child_filter) => { - let parents_predicates = split_conjunction(&filter.predicate); - let set: HashSet<&&Expr> = parents_predicates.iter().collect(); + match unwrap_arc(filter.input) { + LogicalPlan::Filter(child_filter) => { + let parents_predicates = split_conjunction_owned(filter.predicate); + // remove duplicated filters + let child_predicates = split_conjunction_owned(child_filter.predicate); let new_predicates = parents_predicates - .iter() - .chain( - split_conjunction(&child_filter.predicate) - .iter() - .filter(|e| !set.contains(e)), - ) - .map(|e| (*e).clone()) + .into_iter() + .chain(child_predicates) + // use IndexSet to remove dupes while preserving predicate order + .collect::>() + .into_iter() .collect::>(); - let new_predicate = conjunction(new_predicates).ok_or_else(|| { - plan_datafusion_err!("at least one expression exists") - })?; + + let Some(new_predicate) = conjunction(new_predicates) else { + return plan_err!("at least one expression exists"); + }; let new_filter = LogicalPlan::Filter(Filter::try_new( new_predicate, - child_filter.input.clone(), + child_filter.input, )?); - self.rewrite(new_filter, _config)?.data + self.rewrite(new_filter, _config) } - LogicalPlan::Repartition(_) - | LogicalPlan::Distinct(_) - | LogicalPlan::Sort(_) => { - let new_filter = plan.with_new_exprs( - plan.expressions(), - vec![child_plan.inputs()[0].clone()], - )?; - child_plan.with_new_exprs(child_plan.expressions(), vec![new_filter])? + LogicalPlan::Repartition(repartition) => { + let new_filter = + Filter::try_new(filter.predicate, repartition.input.clone()) + .map(LogicalPlan::Filter)?; + insert_below(LogicalPlan::Repartition(repartition), new_filter) } - LogicalPlan::SubqueryAlias(ref subquery_alias) => { + LogicalPlan::Distinct(distinct) => { + let new_filter = + Filter::try_new(filter.predicate, distinct.input().clone()) + .map(LogicalPlan::Filter)?; + insert_below(LogicalPlan::Distinct(distinct), new_filter) + } + LogicalPlan::Sort(sort) => { + let new_filter = Filter::try_new(filter.predicate, sort.input.clone()) + .map(LogicalPlan::Filter)?; + insert_below(LogicalPlan::Sort(sort), new_filter) + } + LogicalPlan::SubqueryAlias(subquery_alias) => { let mut replace_map = HashMap::new(); for (i, (qualifier, field)) in subquery_alias.input.schema().iter().enumerate() @@ -692,15 +682,15 @@ impl OptimizerRule for PushDownFilter { Expr::Column(Column::new(qualifier.cloned(), field.name())), ); } - let new_predicate = - replace_cols_by_name(filter.predicate.clone(), &replace_map)?; + let new_predicate = replace_cols_by_name(filter.predicate, &replace_map)?; + let new_filter = LogicalPlan::Filter(Filter::try_new( new_predicate, subquery_alias.input.clone(), )?); - child_plan.with_new_exprs(child_plan.expressions(), vec![new_filter])? + insert_below(LogicalPlan::SubqueryAlias(subquery_alias), new_filter) } - LogicalPlan::Projection(ref projection) => { + LogicalPlan::Projection(projection) => { // A projection is filter-commutable if it do not contain volatile predicates or contain volatile // predicates that are not used in the filter. However, we should re-writes all predicate expressions. // collect projection. @@ -711,10 +701,7 @@ impl OptimizerRule for PushDownFilter { .enumerate() .map(|(i, (qualifier, field))| { // strip alias, as they should not be part of filters - let expr = match &projection.expr[i] { - Expr::Alias(Alias { expr, .. }) => expr.as_ref().clone(), - expr => expr.clone(), - }; + let expr = projection.expr[i].clone().unalias(); (qualified_name(qualifier, field.name()), expr) }) @@ -741,23 +728,24 @@ impl OptimizerRule for PushDownFilter { )?); match conjunction(keep_predicates) { - None => child_plan.with_new_exprs( - child_plan.expressions(), - vec![new_filter], - )?, - Some(keep_predicate) => { - let child_plan = child_plan.with_new_exprs( - child_plan.expressions(), - vec![new_filter], - )?; - LogicalPlan::Filter(Filter::try_new( - keep_predicate, - Arc::new(child_plan), - )?) - } + None => insert_below( + LogicalPlan::Projection(projection), + new_filter, + ), + Some(keep_predicate) => insert_below( + LogicalPlan::Projection(projection), + new_filter, + )? + .map_data(|child_plan| { + Filter::try_new(keep_predicate, Arc::new(child_plan)) + .map(LogicalPlan::Filter) + }), } } - None => return Ok(Transformed::no(plan)), + None => { + filter.input = Arc::new(LogicalPlan::Projection(projection)); + Ok(Transformed::no(LogicalPlan::Filter(filter))) + } } } LogicalPlan::Union(ref union) => { @@ -780,12 +768,12 @@ impl OptimizerRule for PushDownFilter { input.clone(), )?))) } - LogicalPlan::Union(Union { + Ok(Transformed::yes(LogicalPlan::Union(Union { inputs, - schema: plan.schema().clone(), - }) + schema: plan_schema.clone(), + }))) } - LogicalPlan::Aggregate(ref agg) => { + LogicalPlan::Aggregate(agg) => { // We can push down Predicate which in groupby_expr. let group_expr_columns = agg .group_expr @@ -818,49 +806,33 @@ impl OptimizerRule for PushDownFilter { .map(|expr| replace_cols_by_name(expr.clone(), &replace_map)) .collect::>>()?; - let child = match conjunction(replaced_push_predicates) { - Some(predicate) => LogicalPlan::Filter(Filter::try_new( - predicate, - agg.input.clone(), - )?), - None => (*agg.input).clone(), - }; - let new_agg = filter - .input - .with_new_exprs(filter.input.expressions(), vec![child])?; - match conjunction(keep_predicates) { - Some(predicate) => LogicalPlan::Filter(Filter::try_new( - predicate, - Arc::new(new_agg), - )?), - None => new_agg, - } - } - LogicalPlan::Join(ref join) => { - push_down_join( - &unwrap_arc(filter.clone().input), - join, - Some(&filter.predicate), - )? - .data + let agg_input = agg.input.clone(); + Transformed::yes(LogicalPlan::Aggregate(agg)) + .transform_data(|new_plan| { + // If we have a filter to push, we push it down to the input of the aggregate + if let Some(predicate) = conjunction(replaced_push_predicates) { + let new_filter = make_filter(predicate, agg_input)?; + insert_below(new_plan, new_filter) + } else { + Ok(Transformed::no(new_plan)) + } + })? + .map_data(|child_plan| { + // if there are any remaining predicates we can't push, add them + // back as a filter + if let Some(predicate) = conjunction(keep_predicates) { + make_filter(predicate, Arc::new(child_plan)) + } else { + Ok(child_plan) + } + }) } - LogicalPlan::CrossJoin(ref cross_join) => { + LogicalPlan::Join(join) => push_down_join(join, Some(&filter.predicate)), + LogicalPlan::CrossJoin(cross_join) => { let predicates = split_conjunction_owned(filter.predicate.clone()); - let join = convert_cross_join_to_inner_join(cross_join.clone())?; - let join_plan = LogicalPlan::Join(join); - let inputs = join_plan.inputs(); - let left = inputs[0]; - let right = inputs[1]; - let plan = push_down_all_join( - predicates, - vec![], - &join_plan, - left, - right, - vec![], - true, - )?; - convert_to_cross_join_if_beneficial(plan.data)? + let join = convert_cross_join_to_inner_join(cross_join)?; + let plan = push_down_all_join(predicates, vec![], join, vec![])?; + convert_to_cross_join_if_beneficial(plan.data) } LogicalPlan::TableScan(ref scan) => { let filter_predicates = split_conjunction(&filter.predicate); @@ -901,25 +873,47 @@ impl OptimizerRule for PushDownFilter { fetch: scan.fetch, }); - match conjunction(new_predicate) { - Some(predicate) => LogicalPlan::Filter(Filter::try_new( - predicate, - Arc::new(new_scan), - )?), - None => new_scan, - } + Transformed::yes(new_scan).transform_data(|new_scan| { + if let Some(predicate) = conjunction(new_predicate) { + make_filter(predicate, Arc::new(new_scan)).map(Transformed::yes) + } else { + Ok(Transformed::no(new_scan)) + } + }) } - LogicalPlan::Extension(ref extension_plan) => { + LogicalPlan::Extension(extension_plan) => { let prevent_cols = extension_plan.node.prevent_predicate_push_down_columns(); - let predicates = split_conjunction_owned(filter.predicate.clone()); + // determine if we can push any predicates down past the extension node + + // each element is true for push, false to keep + let predicate_push_or_keep = split_conjunction(&filter.predicate) + .iter() + .map(|expr| { + let cols = expr.to_columns()?; + if cols.iter().any(|c| prevent_cols.contains(&c.name)) { + Ok(false) // No push (keep) + } else { + Ok(true) // push + } + }) + .collect::>>()?; + // all predicates are kept, no changes needed + if predicate_push_or_keep.iter().all(|&x| !x) { + filter.input = Arc::new(LogicalPlan::Extension(extension_plan)); + return Ok(Transformed::no(LogicalPlan::Filter(filter))); + } + + // going to push some predicates down, so split the predicates let mut keep_predicates = vec![]; let mut push_predicates = vec![]; - for expr in predicates { - let cols = expr.to_columns()?; - if cols.iter().any(|c| prevent_cols.contains(&c.name)) { + for (push, expr) in predicate_push_or_keep + .into_iter() + .zip(split_conjunction_owned(filter.predicate).into_iter()) + { + if !push { keep_predicates.push(expr); } else { push_predicates.push(expr); @@ -941,22 +935,65 @@ impl OptimizerRule for PushDownFilter { None => extension_plan.node.inputs().into_iter().cloned().collect(), }; // extension with new inputs. + let child_plan = LogicalPlan::Extension(extension_plan); let new_extension = child_plan.with_new_exprs(child_plan.expressions(), new_children)?; - match conjunction(keep_predicates) { + let new_plan = match conjunction(keep_predicates) { Some(predicate) => LogicalPlan::Filter(Filter::try_new( predicate, Arc::new(new_extension), )?), None => new_extension, - } + }; + Ok(Transformed::yes(new_plan)) } - _ => return Ok(Transformed::no(plan)), - }; + child => { + filter.input = Arc::new(child); + Ok(Transformed::no(LogicalPlan::Filter(filter))) + } + } + } +} + +/// Creates a new LogicalPlan::Filter node. +pub fn make_filter(predicate: Expr, input: Arc) -> Result { + Filter::try_new(predicate, input).map(LogicalPlan::Filter) +} - Ok(Transformed::yes(new_plan)) +/// Replace the existing child of the single input node with `new_child`. +/// +/// Starting: +/// ```text +/// plan +/// child +/// ``` +/// +/// Ending: +/// ```text +/// plan +/// new_child +/// ``` +fn insert_below( + plan: LogicalPlan, + new_child: LogicalPlan, +) -> Result> { + let mut new_child = Some(new_child); + let transformed_plan = plan.map_children(|_child| { + if let Some(new_child) = new_child.take() { + Ok(Transformed::yes(new_child)) + } else { + // already took the new child + internal_err!("node had more than one input") + } + })?; + + // make sure we did the actual replacement + if new_child.is_some() { + return internal_err!("node had no inputs"); } + + Ok(transformed_plan) } impl PushDownFilter { @@ -985,21 +1022,27 @@ fn convert_cross_join_to_inner_join(cross_join: CrossJoin) -> Result { /// Converts the given inner join with an empty equality predicate and an /// empty filter condition to a cross join. -fn convert_to_cross_join_if_beneficial(plan: LogicalPlan) -> Result { - if let LogicalPlan::Join(join) = &plan { +fn convert_to_cross_join_if_beneficial( + plan: LogicalPlan, +) -> Result> { + match plan { // Can be converted back to cross join - if join.on.is_empty() && join.filter.is_none() { - return LogicalPlanBuilder::from(join.left.as_ref().clone()) - .cross_join(join.right.as_ref().clone())? - .build(); + LogicalPlan::Join(join) if join.on.is_empty() && join.filter.is_none() => { + LogicalPlanBuilder::from(unwrap_arc(join.left)) + .cross_join(unwrap_arc(join.right))? + .build() + .map(Transformed::yes) } - } else if let LogicalPlan::Filter(filter) = &plan { - let new_input = - convert_to_cross_join_if_beneficial(filter.input.as_ref().clone())?; - return Filter::try_new(filter.predicate.clone(), Arc::new(new_input)) - .map(LogicalPlan::Filter); + LogicalPlan::Filter(filter) => convert_to_cross_join_if_beneficial(unwrap_arc( + filter.input, + ))? + .transform_data(|child_plan| { + Filter::try_new(filter.predicate, Arc::new(child_plan)) + .map(LogicalPlan::Filter) + .map(Transformed::yes) + }), + plan => Ok(Transformed::no(plan)), } - Ok(plan) } /// replaces columns by its name on the projection. From 32b63ff95a8c2cc2d56be0b265d8a13e0ca42f96 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 16 May 2024 20:09:00 -0400 Subject: [PATCH 15/15] Stop copying LogicalPlan and Exprs in `ScalarSubqueryToJoin` (#10489) --- .../optimizer/src/scalar_subquery_to_join.rs | 83 ++++++++++++------- 1 file changed, 54 insertions(+), 29 deletions(-) diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index b7fce68fb3cc..71692b934543 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -29,7 +29,7 @@ use datafusion_common::alias::AliasGenerator; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, }; -use datafusion_common::{plan_err, Column, Result, ScalarValue}; +use datafusion_common::{internal_err, plan_err, Column, Result, ScalarValue}; use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; use datafusion_expr::logical_plan::{JoinType, Subquery}; use datafusion_expr::utils::conjunction; @@ -50,7 +50,7 @@ impl ScalarSubqueryToJoin { /// # Arguments /// * `predicate` - A conjunction to split and search /// - /// Returns a tuple (subqueries, rewrite expression) + /// Returns a tuple (subqueries, alias) fn extract_subquery_exprs( &self, predicate: &Expr, @@ -71,19 +71,36 @@ impl ScalarSubqueryToJoin { impl OptimizerRule for ScalarSubqueryToJoin { fn try_optimize( &self, - plan: &LogicalPlan, - config: &dyn OptimizerConfig, + _plan: &LogicalPlan, + _config: &dyn OptimizerConfig, ) -> Result> { + internal_err!("Should have called ScalarSubqueryToJoin::rewrite") + } + + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( + &self, + plan: LogicalPlan, + config: &dyn OptimizerConfig, + ) -> Result> { match plan { LogicalPlan::Filter(filter) => { + // Optimization: skip the rest of the rule and its copies if + // there are no scalar subqueries + if !contains_scalar_subquery(&filter.predicate) { + return Ok(Transformed::no(LogicalPlan::Filter(filter))); + } + let (subqueries, mut rewrite_expr) = self.extract_subquery_exprs( &filter.predicate, config.alias_generator(), )?; if subqueries.is_empty() { - // regular filter, no subquery exists clause here - return Ok(None); + return internal_err!("Expected subqueries not found in filter"); } // iterate through all subqueries in predicate, turning each into a left join @@ -94,16 +111,13 @@ impl OptimizerRule for ScalarSubqueryToJoin { { if !expr_check_map.is_empty() { rewrite_expr = rewrite_expr - .clone() .transform_up(|expr| { - if let Expr::Column(col) = &expr { - if let Some(map_expr) = - expr_check_map.get(&col.name) - { - Ok(Transformed::yes(map_expr.clone())) - } else { - Ok(Transformed::no(expr)) - } + // replace column references with entry in map, if it exists + if let Some(map_expr) = expr + .try_as_col() + .and_then(|col| expr_check_map.get(&col.name)) + { + Ok(Transformed::yes(map_expr.clone())) } else { Ok(Transformed::no(expr)) } @@ -113,15 +127,21 @@ impl OptimizerRule for ScalarSubqueryToJoin { cur_input = optimized_subquery; } else { // if we can't handle all of the subqueries then bail for now - return Ok(None); + return Ok(Transformed::no(LogicalPlan::Filter(filter))); } } let new_plan = LogicalPlanBuilder::from(cur_input) .filter(rewrite_expr)? .build()?; - Ok(Some(new_plan)) + Ok(Transformed::yes(new_plan)) } LogicalPlan::Projection(projection) => { + // Optimization: skip the rest of the rule and its copies if + // there are no scalar subqueries + if !projection.expr.iter().any(contains_scalar_subquery) { + return Ok(Transformed::no(LogicalPlan::Projection(projection))); + } + let mut all_subqueryies = vec![]; let mut expr_to_rewrite_expr_map = HashMap::new(); let mut subquery_to_expr_map = HashMap::new(); @@ -135,8 +155,7 @@ impl OptimizerRule for ScalarSubqueryToJoin { expr_to_rewrite_expr_map.insert(expr, rewrite_exprs); } if all_subqueryies.is_empty() { - // regular projection, no subquery exists clause here - return Ok(None); + return internal_err!("Expected subqueries not found in projection"); } // iterate through all subqueries in predicate, turning each into a left join let mut cur_input = projection.input.as_ref().clone(); @@ -153,14 +172,13 @@ impl OptimizerRule for ScalarSubqueryToJoin { let new_expr = rewrite_expr .clone() .transform_up(|expr| { - if let Expr::Column(col) = &expr { - if let Some(map_expr) = + // replace column references with entry in map, if it exists + if let Some(map_expr) = + expr.try_as_col().and_then(|col| { expr_check_map.get(&col.name) - { - Ok(Transformed::yes(map_expr.clone())) - } else { - Ok(Transformed::no(expr)) - } + }) + { + Ok(Transformed::yes(map_expr.clone())) } else { Ok(Transformed::no(expr)) } @@ -172,7 +190,7 @@ impl OptimizerRule for ScalarSubqueryToJoin { } } else { // if we can't handle all of the subqueries then bail for now - return Ok(None); + return Ok(Transformed::no(LogicalPlan::Projection(projection))); } } @@ -190,10 +208,10 @@ impl OptimizerRule for ScalarSubqueryToJoin { let new_plan = LogicalPlanBuilder::from(cur_input) .project(proj_exprs)? .build()?; - Ok(Some(new_plan)) + Ok(Transformed::yes(new_plan)) } - _ => Ok(None), + plan => Ok(Transformed::no(plan)), } } @@ -206,6 +224,13 @@ impl OptimizerRule for ScalarSubqueryToJoin { } } +/// Returns true if the expression has a scalar subquery somewhere in it +/// false otherwise +fn contains_scalar_subquery(expr: &Expr) -> bool { + expr.exists(|expr| Ok(matches!(expr, Expr::ScalarSubquery(_)))) + .expect("Inner is always Ok") +} + struct ExtractScalarSubQuery { sub_query_info: Vec<(Subquery, String)>, alias_gen: Arc,