From 5309d83e31ed7fbf96b287e0dda9b9723b90a5c2 Mon Sep 17 00:00:00 2001 From: Marco Neumann Date: Mon, 11 Nov 2024 17:42:17 +0100 Subject: [PATCH 01/17] refactor: replace `instant` with `web-time` (#13355) See https://rustsec.org/advisories/RUSTSEC-2024-0384.html . --- datafusion-cli/Cargo.lock | 24 +++++++++++------------- datafusion/common/Cargo.toml | 2 +- datafusion/common/src/instant.rs | 4 ++-- 3 files changed, 14 insertions(+), 16 deletions(-) diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 9983e247f9ac..b4d790ebb0f3 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1300,7 +1300,6 @@ dependencies = [ "half", "hashbrown 0.14.5", "indexmap", - "instant", "libc", "num_cpus", "object_store", @@ -1308,6 +1307,7 @@ dependencies = [ "paste", "sqlparser", "tokio", + "web-time", ] [[package]] @@ -2354,18 +2354,6 @@ dependencies = [ "hashbrown 0.15.1", ] -[[package]] -name = "instant" -version = "0.1.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0242819d153cba4b4b05a5a8f2a7e9bbf97b6055b2a002b395c96b5ff3c0222" -dependencies = [ - "cfg-if", - "js-sys", - "wasm-bindgen", - "web-sys", -] - [[package]] name = "integer-encoding" version = "3.0.4" @@ -4275,6 +4263,16 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "winapi-util" version = "0.1.9" diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml index 0747672a18f6..c398fe97d2a4 100644 --- a/datafusion/common/Cargo.toml +++ b/datafusion/common/Cargo.toml @@ -67,7 +67,7 @@ sqlparser = { workspace = true } tokio = { workspace = true } [target.'cfg(target_family = "wasm")'.dependencies] -instant = { version = "0.1", features = ["wasm-bindgen"] } +web-time = "1.1.0" [dev-dependencies] rand = { workspace = true } diff --git a/datafusion/common/src/instant.rs b/datafusion/common/src/instant.rs index 6401bc29c942..42f21c061c0c 100644 --- a/datafusion/common/src/instant.rs +++ b/datafusion/common/src/instant.rs @@ -18,9 +18,9 @@ //! WASM-compatible `Instant` wrapper. #[cfg(target_family = "wasm")] -/// DataFusion wrapper around [`std::time::Instant`]. Uses [`instant::Instant`] +/// DataFusion wrapper around [`std::time::Instant`]. Uses [`web_time::Instant`] /// under `wasm` feature gate. It provides the same API as [`std::time::Instant`]. 
-pub type Instant = instant::Instant; +pub type Instant = web_time::Instant; #[allow(clippy::disallowed_types)] #[cfg(not(target_family = "wasm"))] From bab02b623119d4e4f9192b61ce6e654575ee73db Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Mon, 11 Nov 2024 18:53:19 +0100 Subject: [PATCH 02/17] Add stacker and recursive (#13310) * Add stacker and recursive * add test * fine tune set_expr_to_plan min stack size * format tomls * add recursive annotation to more custom recursive rules * fix comments * fix comment about min stack requirement --------- Co-authored-by: blaginin --- Cargo.toml | 1 + datafusion-cli/Cargo.lock | 47 +++++++++++++++++++ datafusion/common/Cargo.toml | 1 + datafusion/common/src/tree_node.rs | 20 ++++++++ datafusion/expr/Cargo.toml | 1 + datafusion/expr/src/logical_plan/tree_node.rs | 7 +++ datafusion/optimizer/Cargo.toml | 1 + datafusion/optimizer/src/analyzer/subquery.rs | 2 + .../optimizer/src/common_subexpr_eliminate.rs | 2 + .../optimizer/src/eliminate_cross_join.rs | 5 +- .../optimizer/src/optimize_projections/mod.rs | 7 +-- datafusion/physical-optimizer/Cargo.toml | 1 + .../src/aggregate_statistics.rs | 5 +- datafusion/sql/Cargo.toml | 1 + datafusion/sql/src/expr/mod.rs | 11 +++-- datafusion/sql/src/expr/value.rs | 2 - datafusion/sql/src/query.rs | 6 +++ datafusion/sql/src/set_expr.rs | 2 + 18 files changed, 109 insertions(+), 13 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 91f09102ce48..0b5c74e15d13 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -141,6 +141,7 @@ pbjson = { version = "0.7.0" } prost = "0.13.1" prost-derive = "0.13.1" rand = "0.8" +recursive = "0.1.1" regex = "1.8" rstest = "0.23.0" serde_json = "1" diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index b4d790ebb0f3..02bd01a49905 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1305,6 +1305,7 @@ dependencies = [ "object_store", "parquet", "paste", + "recursive", "sqlparser", "tokio", "web-time", @@ -1353,6 +1354,7 @@ dependencies = [ "datafusion-physical-expr-common", "indexmap", "paste", + "recursive", "serde_json", "sqlparser", "strum 0.26.3", @@ -1482,6 +1484,7 @@ dependencies = [ "itertools", "log", "paste", + "recursive", "regex", "regex-syntax", ] @@ -1537,6 +1540,7 @@ dependencies = [ "datafusion-physical-expr", "datafusion-physical-plan", "itertools", + "recursive", ] [[package]] @@ -1583,6 +1587,7 @@ dependencies = [ "datafusion-expr", "indexmap", "log", + "recursive", "regex", "sqlparser", "strum 0.26.3", @@ -3030,6 +3035,15 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "psm" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa37f80ca58604976033fae9515a8a2989fc13797d953f7c04fb8fa36a11f205" +dependencies = [ + "cc", +] + [[package]] name = "quad-rand" version = "0.2.2" @@ -3144,6 +3158,26 @@ dependencies = [ "getrandom", ] +[[package]] +name = "recursive" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0786a43debb760f491b1bc0269fe5e84155353c67482b9e60d0cfb596054b43e" +dependencies = [ + "recursive-proc-macro-impl", + "stacker", +] + +[[package]] +name = "recursive-proc-macro-impl" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76009fbe0614077fc1a2ce255e3a1881a2e3a3527097d5dc6d8212c585e7e38b" +dependencies = [ + "quote", + "syn", +] + [[package]] name = "redox_syscall" version = "0.5.7" @@ -3695,6 +3729,19 @@ version = "1.2.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +[[package]] +name = "stacker" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "799c883d55abdb5e98af1a7b3f23b9b6de8ecada0ecac058672d7635eb48ca7b" +dependencies = [ + "cc", + "cfg-if", + "libc", + "psm", + "windows-sys 0.59.0", +] + [[package]] name = "static_assertions" version = "1.1.0" diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml index c398fe97d2a4..9f2db95721f5 100644 --- a/datafusion/common/Cargo.toml +++ b/datafusion/common/Cargo.toml @@ -63,6 +63,7 @@ object_store = { workspace = true, optional = true } parquet = { workspace = true, optional = true, default-features = true } paste = "1.0.15" pyo3 = { version = "0.22.0", optional = true } +recursive = { workspace = true } sqlparser = { workspace = true } tokio = { workspace = true } diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index a0ad1e80be9b..c8ec7f18339a 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -17,6 +17,7 @@ //! [`TreeNode`] for visiting and rewriting expression and plan trees +use recursive::recursive; use std::sync::Arc; use crate::Result; @@ -123,6 +124,7 @@ pub trait TreeNode: Sized { /// TreeNodeVisitor::f_up(ChildNode2) /// TreeNodeVisitor::f_up(ParentNode) /// ``` + #[recursive] fn visit<'n, V: TreeNodeVisitor<'n, Node = Self>>( &'n self, visitor: &mut V, @@ -172,6 +174,7 @@ pub trait TreeNode: Sized { /// TreeNodeRewriter::f_up(ChildNode2) /// TreeNodeRewriter::f_up(ParentNode) /// ``` + #[recursive] fn rewrite>( self, rewriter: &mut R, @@ -194,6 +197,7 @@ pub trait TreeNode: Sized { &'n self, mut f: F, ) -> Result { + #[recursive] fn apply_impl<'n, N: TreeNode, F: FnMut(&'n N) -> Result>( node: &'n N, f: &mut F, @@ -228,6 +232,7 @@ pub trait TreeNode: Sized { self, mut f: F, ) -> Result> { + #[recursive] fn transform_down_impl Result>>( node: N, f: &mut F, @@ -251,6 +256,7 @@ pub trait TreeNode: Sized { self, mut f: F, ) -> Result> { + #[recursive] fn transform_up_impl Result>>( node: N, f: &mut F, @@ -365,6 +371,7 @@ pub trait TreeNode: Sized { mut f_down: FD, mut f_up: FU, ) -> Result> { + #[recursive] fn transform_down_up_impl< N: TreeNode, FD: FnMut(N) -> Result>, @@ -2079,4 +2086,17 @@ pub(crate) mod tests { Ok(()) } + + #[test] + fn test_large_tree() { + let mut item = TestTreeNode::new_leaf("initial".to_string()); + for i in 0..3000 { + item = TestTreeNode::new(vec![item], format!("parent-{}", i)); + } + + let mut visitor = + TestVisitor::new(Box::new(visit_continue), Box::new(visit_continue)); + + item.visit(&mut visitor).unwrap(); + } } diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index d7dc1afe4d50..19cd5ed3158b 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -50,6 +50,7 @@ datafusion-functions-window-common = { workspace = true } datafusion-physical-expr-common = { workspace = true } indexmap = { workspace = true } paste = "^1.0" +recursive = { workspace = true } serde_json = { workspace = true } sqlparser = { workspace = true } strum = { version = "0.26.1", features = ["derive"] } diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index d16fe42098f5..e7dfe8791924 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -42,6 +42,7 @@ use crate::{ 
LogicalPlan, Partitioning, Projection, RecursiveQuery, Repartition, Sort, Subquery, SubqueryAlias, TableScan, Union, Unnest, UserDefinedLogicalNode, Values, Window, }; +use recursive::recursive; use std::ops::Deref; use std::sync::Arc; @@ -745,6 +746,7 @@ impl LogicalPlan { /// Visits a plan similarly to [`Self::visit`], including subqueries that /// may appear in expressions such as `IN (SELECT ...)`. + #[recursive] pub fn visit_with_subqueries TreeNodeVisitor<'n, Node = Self>>( &self, visitor: &mut V, @@ -761,6 +763,7 @@ impl LogicalPlan { /// Similarly to [`Self::rewrite`], rewrites this node and its inputs using `f`, /// including subqueries that may appear in expressions such as `IN (SELECT /// ...)`. + #[recursive] pub fn rewrite_with_subqueries>( self, rewriter: &mut R, @@ -779,6 +782,7 @@ impl LogicalPlan { &self, mut f: F, ) -> Result { + #[recursive] fn apply_with_subqueries_impl< F: FnMut(&LogicalPlan) -> Result, >( @@ -814,6 +818,7 @@ impl LogicalPlan { self, mut f: F, ) -> Result> { + #[recursive] fn transform_down_with_subqueries_impl< F: FnMut(LogicalPlan) -> Result>, >( @@ -839,6 +844,7 @@ impl LogicalPlan { self, mut f: F, ) -> Result> { + #[recursive] fn transform_up_with_subqueries_impl< F: FnMut(LogicalPlan) -> Result>, >( @@ -866,6 +872,7 @@ impl LogicalPlan { mut f_down: FD, mut f_up: FU, ) -> Result> { + #[recursive] fn transform_down_up_with_subqueries_impl< FD: FnMut(LogicalPlan) -> Result>, FU: FnMut(LogicalPlan) -> Result>, diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index 2ea3ebf337eb..34e35c66107a 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -47,6 +47,7 @@ indexmap = { workspace = true } itertools = { workspace = true } log = { workspace = true } paste = "1.0.14" +recursive = { workspace = true } regex = { workspace = true } regex-syntax = "0.8.0" diff --git a/datafusion/optimizer/src/analyzer/subquery.rs b/datafusion/optimizer/src/analyzer/subquery.rs index fa04835f0967..0b54b302c2df 100644 --- a/datafusion/optimizer/src/analyzer/subquery.rs +++ b/datafusion/optimizer/src/analyzer/subquery.rs @@ -17,6 +17,7 @@ use crate::analyzer::check_plan; use crate::utils::collect_subquery_cols; +use recursive::recursive; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{plan_err, Result}; @@ -128,6 +129,7 @@ fn check_correlations_in_subquery(inner_plan: &LogicalPlan) -> Result<()> { } // Recursively check the unsupported outer references in the sub query plan. 
+#[recursive] fn check_inner_plan(inner_plan: &LogicalPlan, can_contain_outer_ref: bool) -> Result<()> { if !can_contain_outer_ref && inner_plan.contains_outer_reference() { return plan_err!("Accessing outer reference columns is not allowed in the plan"); diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 71327ad3e21d..16a4fa6be38d 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -22,6 +22,7 @@ use std::fmt::Debug; use std::sync::Arc; use crate::{OptimizerConfig, OptimizerRule}; +use recursive::recursive; use crate::optimizer::ApplyOrder; use crate::utils::NamePreserver; @@ -531,6 +532,7 @@ impl OptimizerRule for CommonSubexprEliminate { None } + #[recursive] fn rewrite( &self, plan: LogicalPlan, diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index 65ebac2106ad..32b7ce44a63a 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -16,9 +16,9 @@ // under the License. //! [`EliminateCrossJoin`] converts `CROSS JOIN` to `INNER JOIN` if join predicates are available. -use std::sync::Arc; - use crate::{OptimizerConfig, OptimizerRule}; +use recursive::recursive; +use std::sync::Arc; use crate::join_key_set::JoinKeySet; use datafusion_common::tree_node::{Transformed, TreeNode}; @@ -80,6 +80,7 @@ impl OptimizerRule for EliminateCrossJoin { true } + #[recursive] fn rewrite( &self, plan: LogicalPlan, diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 04a523f9b115..b659e477f67e 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -19,11 +19,11 @@ mod required_indices; -use std::collections::HashSet; -use std::sync::Arc; - use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; +use recursive::recursive; +use std::collections::HashSet; +use std::sync::Arc; use datafusion_common::{ get_required_group_by_exprs_indices, internal_datafusion_err, internal_err, Column, @@ -110,6 +110,7 @@ impl OptimizerRule for OptimizeProjections { /// columns. /// - `Ok(None)`: Signal that the given logical plan did not require any change. /// - `Err(error)`: An error occurred during the optimization process. +#[recursive] fn optimize_projections( plan: LogicalPlan, config: &dyn OptimizerConfig, diff --git a/datafusion/physical-optimizer/Cargo.toml b/datafusion/physical-optimizer/Cargo.toml index e7bf4a80fc45..04f01f8badb8 100644 --- a/datafusion/physical-optimizer/Cargo.toml +++ b/datafusion/physical-optimizer/Cargo.toml @@ -40,6 +40,7 @@ datafusion-expr-common = { workspace = true, default-features = true } datafusion-physical-expr = { workspace = true } datafusion-physical-plan = { workspace = true } itertools = { workspace = true } +recursive = { workspace = true } [dev-dependencies] datafusion-functions-aggregate = { workspace = true } diff --git a/datafusion/physical-optimizer/src/aggregate_statistics.rs b/datafusion/physical-optimizer/src/aggregate_statistics.rs index 27870c7865f3..87077183110d 100644 --- a/datafusion/physical-optimizer/src/aggregate_statistics.rs +++ b/datafusion/physical-optimizer/src/aggregate_statistics.rs @@ -16,8 +16,6 @@ // under the License. //! 
Utilizing exact statistics from sources to avoid scanning data -use std::sync::Arc; - use datafusion_common::config::ConfigOptions; use datafusion_common::scalar::ScalarValue; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; @@ -27,6 +25,8 @@ use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; use datafusion_physical_plan::projection::ProjectionExec; use datafusion_physical_plan::udaf::{AggregateFunctionExpr, StatisticsArgs}; use datafusion_physical_plan::{expressions, ExecutionPlan}; +use recursive::recursive; +use std::sync::Arc; use crate::PhysicalOptimizerRule; @@ -42,6 +42,7 @@ impl AggregateStatistics { } impl PhysicalOptimizerRule for AggregateStatistics { + #[recursive] fn optimize( &self, plan: Arc, diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index 1eef1b718ba6..94c3ce97a441 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -48,6 +48,7 @@ datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } indexmap = { workspace = true } log = { workspace = true } +recursive = { workspace = true } regex = { workspace = true } sqlparser = { workspace = true } strum = { version = "0.26.1", features = ["derive"] } diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index b68be90b03e1..72f88abcea99 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -20,6 +20,7 @@ use arrow_schema::TimeUnit; use datafusion_expr::planner::{ PlannerResult, RawBinaryExpr, RawDictionaryExpr, RawFieldAccessExpr, }; +use recursive::recursive; use sqlparser::ast::{ BinaryOperator, CastFormat, CastKind, DataType as SQLDataType, DictionaryField, Expr as SQLExpr, MapEntry, StructField, Subscript, TrimWhereField, Value, @@ -168,16 +169,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { /// Internal implementation. Use /// [`Self::sql_expr_to_logical_expr`] to plan exprs. + #[recursive] fn sql_expr_to_logical_expr_internal( &self, sql: SQLExpr, schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { - // NOTE: This function is called recusively, so each match arm body should be as - // small as possible to avoid stack overflows in debug builds. Follow the - // common pattern of extracting into a separate function for non-trivial - // arms. See https://github.com/apache/datafusion/pull/12384 for more context. + // NOTE: This function is called recursively, so each match arm body should be as + // small as possible to decrease stack requirement. + // Follow the common pattern of extracting into a separate function for + // non-trivial arms. See https://github.com/apache/datafusion/pull/12384 for + // more context. match sql { SQLExpr::Value(value) => { self.parse_value(value, planner_context.prepare_param_data_types()) diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index 7dc15de7ad71..1cf090aa64aa 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -133,8 +133,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ))) } - // IMPORTANT: Keep sql_array_literal's function body small to prevent stack overflow - // This function is recursively called, potentially leading to deep call stacks. 
pub(super) fn sql_array_literal( &self, elements: Vec, diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index 1ef009132f9e..740f9ad3b42c 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -59,7 +59,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { self.select_into(plan, select_into) } other => { + // The functions called from `set_expr_to_plan()` need more than 128KB + // stack in debug builds as investigated in: + // https://github.com/apache/datafusion/pull/13310#discussion_r1836813902 + let min_stack_size = recursive::get_minimum_stack_size(); + recursive::set_minimum_stack_size(256 * 1024); let plan = self.set_expr_to_plan(other, planner_context)?; + recursive::set_minimum_stack_size(min_stack_size); let oby_exprs = to_order_by_exprs(query.order_by)?; let order_by_rex = self.order_by_to_sort_expr( oby_exprs, diff --git a/datafusion/sql/src/set_expr.rs b/datafusion/sql/src/set_expr.rs index 248aad846996..e56ebb4d323f 100644 --- a/datafusion/sql/src/set_expr.rs +++ b/datafusion/sql/src/set_expr.rs @@ -18,9 +18,11 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_common::{not_impl_err, Result}; use datafusion_expr::{LogicalPlan, LogicalPlanBuilder}; +use recursive::recursive; use sqlparser::ast::{SetExpr, SetOperator, SetQuantifier}; impl<'a, S: ContextProvider> SqlToRel<'a, S> { + #[recursive] pub(super) fn set_expr_to_plan( &self, set_expr: SetExpr, From 382ba2b1605d244a10db21f4527e11570caa4323 Mon Sep 17 00:00:00 2001 From: NoeB Date: Mon, 11 Nov 2024 18:54:19 +0100 Subject: [PATCH 03/17] minor(docs): Correct array_prepend docs (#13362) * minor(docs): Correct array_prepend docs * formatted docs with prettier --- docs/source/user-guide/expressions.md | 2 +- docs/source/user-guide/sql/scalar_functions.md | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index ababb001f5c5..03ab86eeb813 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -228,7 +228,7 @@ select log(-1), log(0), sqrt(-1); | array_pop_back(array) | Returns the array without the last element. `array_pop_back([1, 2, 3]) -> [1, 2]` | | array_position(array, element) | Searches for an element in the array, returns first occurrence. `array_position([1, 2, 2, 3, 4], 2) -> 2` | | array_positions(array, element) | Searches for an element in the array, returns all occurrences. `array_positions([1, 2, 2, 3, 4], 2) -> [2, 3]` | -| array_prepend(array, element) | Prepends an element to the beginning of an array. `array_prepend(1, [2, 3, 4]) -> [1, 2, 3, 4]` | +| array_prepend(element, array) | Prepends an element to the beginning of an array. `array_prepend(1, [2, 3, 4]) -> [1, 2, 3, 4]` | | array_repeat(element, count) | Returns an array containing element `count` times. `array_repeat(1, 3) -> [1, 1, 1]` | | array_remove(array, element) | Removes the first element from the array equal to the given value. `array_remove([1, 2, 2, 3, 2, 1, 4], 2) -> [1, 2, 3, 2, 1, 4]` | | array_remove_n(array, element, max) | Removes the first `max` elements from the array equal to the given value. 
`array_remove_n([1, 2, 2, 3, 2, 1, 4], 2, 2) -> [1, 3, 2, 1, 4]` | diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 232efb02d423..490462909b59 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -3081,23 +3081,23 @@ array_position(array, element, index) ### `array_prepend` -Appends an element to the end of an array. +Prepends an element to the beginning of an array. ``` -array_append(array, element) +array_prepend(element, array) ``` #### Arguments -- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. - **element**: Element to append to the array. +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. #### Example ```sql -> select array_append([1, 2, 3], 4); +> select array_prepend(1, [2, 3, 4]); +--------------------------------------+ -| array_append(List([1,2,3]),Int64(4)) | +| array_prepend(Int64(1), List([2,3,4])) | +--------------------------------------+ | [1, 2, 3, 4] | +--------------------------------------+ From cadeb53e4ad36a6a12238e3acc833c6405c75dd2 Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Mon, 11 Nov 2024 20:28:48 -0500 Subject: [PATCH 04/17] fix ci (#13367) --- datafusion/functions-nested/src/concat.rs | 4 +++- docs/source/user-guide/sql/scalar_functions.md | 12 ++++++------ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/datafusion/functions-nested/src/concat.rs b/datafusion/functions-nested/src/concat.rs index 1bdcf74aee2a..4aa6bb5da9b2 100644 --- a/datafusion/functions-nested/src/concat.rs +++ b/datafusion/functions-nested/src/concat.rs @@ -195,8 +195,10 @@ impl ScalarUDFImpl for ArrayPrepend { } } +static DOCUMENTATION_PREPEND: OnceLock = OnceLock::new(); + fn get_array_prepend_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { + DOCUMENTATION_PREPEND.get_or_init(|| { Documentation::builder() .with_doc_section(DOC_SECTION_ARRAY) .with_description( diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 490462909b59..e9cd2bba7d11 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -3089,18 +3089,18 @@ array_prepend(element, array) #### Arguments -- **element**: Element to append to the array. +- **element**: Element to prepend to the array. - **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. 
#### Example ```sql > select array_prepend(1, [2, 3, 4]); -+--------------------------------------+ -| array_prepend(Int64(1), List([2,3,4])) | -+--------------------------------------+ -| [1, 2, 3, 4] | -+--------------------------------------+ ++---------------------------------------+ +| array_prepend(Int64(1),List([2,3,4])) | ++---------------------------------------+ +| [1, 2, 3, 4] | ++---------------------------------------+ ``` #### Aliases From cb7ec85ee3a41e368563e60567104a77240fc7b4 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Tue, 12 Nov 2024 09:34:25 +0800 Subject: [PATCH 05/17] introduce information_schema.parameters table (#13341) --- .../src/catalog_common/information_schema.rs | 243 +++++++++++++++++- .../test_files/information_schema.slt | 57 ++++ .../information_schema_multiple_catalogs.slt | 4 + .../information_schema_table_types.slt | 1 + 4 files changed, 302 insertions(+), 3 deletions(-) diff --git a/datafusion/core/src/catalog_common/information_schema.rs b/datafusion/core/src/catalog_common/information_schema.rs index 53c1a8b11e1c..72f842d3675e 100644 --- a/datafusion/core/src/catalog_common/information_schema.rs +++ b/datafusion/core/src/catalog_common/information_schema.rs @@ -38,7 +38,7 @@ use arrow_array::builder::BooleanBuilder; use async_trait::async_trait; use datafusion_common::error::Result; use datafusion_common::DataFusionError; -use datafusion_expr::{AggregateUDF, ScalarUDF, Signature, WindowUDF}; +use datafusion_expr::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF}; use std::collections::{HashMap, HashSet}; use std::fmt::Debug; use std::{any::Any, sync::Arc}; @@ -50,10 +50,18 @@ pub(crate) const COLUMNS: &str = "columns"; pub(crate) const DF_SETTINGS: &str = "df_settings"; pub(crate) const SCHEMATA: &str = "schemata"; pub(crate) const ROUTINES: &str = "routines"; +pub(crate) const PARAMETERS: &str = "parameters"; /// All information schema tables -pub const INFORMATION_SCHEMA_TABLES: &[&str] = - &[TABLES, VIEWS, COLUMNS, DF_SETTINGS, SCHEMATA, ROUTINES]; +pub const INFORMATION_SCHEMA_TABLES: &[&str] = &[ + TABLES, + VIEWS, + COLUMNS, + DF_SETTINGS, + SCHEMATA, + ROUTINES, + PARAMETERS, +]; /// Implements the `information_schema` virtual schema and tables /// @@ -286,6 +294,102 @@ impl InformationSchemaConfig { fn is_deterministic(signature: &Signature) -> bool { signature.volatility == Volatility::Immutable } + fn make_parameters( + &self, + udfs: &HashMap>, + udafs: &HashMap>, + udwfs: &HashMap>, + config_options: &ConfigOptions, + builder: &mut InformationSchemaParametersBuilder, + ) -> Result<()> { + let catalog_name = &config_options.catalog.default_catalog; + let schema_name = &config_options.catalog.default_schema; + let mut add_parameters = |func_name: &str, + args: Option<&Vec<(String, String)>>, + arg_types: Vec, + return_type: Option, + is_variadic: bool| { + for (position, type_name) in arg_types.iter().enumerate() { + let param_name = + args.and_then(|a| a.get(position).map(|arg| arg.0.as_str())); + builder.add_parameter( + catalog_name, + schema_name, + func_name, + position as u64 + 1, + "IN", + param_name, + type_name, + None::<&str>, + is_variadic, + ); + } + if let Some(return_type) = return_type { + builder.add_parameter( + catalog_name, + schema_name, + func_name, + 1, + "OUT", + None::<&str>, + return_type.as_str(), + None::<&str>, + false, + ); + } + }; + + for (func_name, udf) in udfs { + let args = udf.documentation().and_then(|d| d.arguments.clone()); + let combinations = get_udf_args_and_return_types(udf)?; + for 
(arg_types, return_type) in combinations { + add_parameters( + func_name, + args.as_ref(), + arg_types, + return_type, + Self::is_variadic(udf.signature()), + ); + } + } + + for (func_name, udaf) in udafs { + let args = udaf.documentation().and_then(|d| d.arguments.clone()); + let combinations = get_udaf_args_and_return_types(udaf)?; + for (arg_types, return_type) in combinations { + add_parameters( + func_name, + args.as_ref(), + arg_types, + return_type, + Self::is_variadic(udaf.signature()), + ); + } + } + + for (func_name, udwf) in udwfs { + let args = udwf.documentation().and_then(|d| d.arguments.clone()); + let combinations = get_udwf_args_and_return_types(udwf)?; + for (arg_types, return_type) in combinations { + add_parameters( + func_name, + args.as_ref(), + arg_types, + return_type, + Self::is_variadic(udwf.signature()), + ); + } + } + + Ok(()) + } + + fn is_variadic(signature: &Signature) -> bool { + matches!( + signature.type_signature, + TypeSignature::Variadic(_) | TypeSignature::VariadicAny + ) + } } /// get the arguments and return types of a UDF @@ -384,6 +488,7 @@ impl SchemaProvider for InformationSchemaProvider { DF_SETTINGS => Arc::new(InformationSchemaDfSettings::new(config)), SCHEMATA => Arc::new(InformationSchemata::new(config)), ROUTINES => Arc::new(InformationSchemaRoutines::new(config)), + PARAMETERS => Arc::new(InformationSchemaParameters::new(config)), _ => return Ok(None), }; @@ -1098,3 +1203,135 @@ impl PartitionStream for InformationSchemaRoutines { )) } } + +#[derive(Debug)] +struct InformationSchemaParameters { + schema: SchemaRef, + config: InformationSchemaConfig, +} + +impl InformationSchemaParameters { + fn new(config: InformationSchemaConfig) -> Self { + let schema = Arc::new(Schema::new(vec![ + Field::new("specific_catalog", DataType::Utf8, false), + Field::new("specific_schema", DataType::Utf8, false), + Field::new("specific_name", DataType::Utf8, false), + Field::new("ordinal_position", DataType::UInt64, false), + Field::new("parameter_mode", DataType::Utf8, false), + Field::new("parameter_name", DataType::Utf8, true), + Field::new("data_type", DataType::Utf8, false), + Field::new("parameter_default", DataType::Utf8, true), + Field::new("is_variadic", DataType::Boolean, false), + ])); + + Self { schema, config } + } + + fn builder(&self) -> InformationSchemaParametersBuilder { + InformationSchemaParametersBuilder { + schema: self.schema.clone(), + specific_catalog: StringBuilder::new(), + specific_schema: StringBuilder::new(), + specific_name: StringBuilder::new(), + ordinal_position: UInt64Builder::new(), + parameter_mode: StringBuilder::new(), + parameter_name: StringBuilder::new(), + data_type: StringBuilder::new(), + parameter_default: StringBuilder::new(), + is_variadic: BooleanBuilder::new(), + inserted: HashSet::new(), + } + } +} + +struct InformationSchemaParametersBuilder { + schema: SchemaRef, + specific_catalog: StringBuilder, + specific_schema: StringBuilder, + specific_name: StringBuilder, + ordinal_position: UInt64Builder, + parameter_mode: StringBuilder, + parameter_name: StringBuilder, + data_type: StringBuilder, + parameter_default: StringBuilder, + is_variadic: BooleanBuilder, + // use HashSet to avoid duplicate rows. 
The key is (specific_name, ordinal_position, parameter_mode, data_type) + inserted: HashSet<(String, u64, String, String)>, +} + +impl InformationSchemaParametersBuilder { + #[allow(clippy::too_many_arguments)] + fn add_parameter( + &mut self, + specific_catalog: impl AsRef, + specific_schema: impl AsRef, + specific_name: impl AsRef, + ordinal_position: u64, + parameter_mode: impl AsRef, + parameter_name: Option>, + data_type: impl AsRef, + parameter_default: Option>, + is_variadic: bool, + ) { + let key = ( + specific_name.as_ref().to_string(), + ordinal_position, + parameter_mode.as_ref().to_string(), + data_type.as_ref().to_string(), + ); + if self.inserted.insert(key) { + self.specific_catalog + .append_value(specific_catalog.as_ref()); + self.specific_schema.append_value(specific_schema.as_ref()); + self.specific_name.append_value(specific_name.as_ref()); + self.ordinal_position.append_value(ordinal_position); + self.parameter_mode.append_value(parameter_mode.as_ref()); + self.parameter_name.append_option(parameter_name.as_ref()); + self.data_type.append_value(data_type.as_ref()); + self.parameter_default.append_option(parameter_default); + self.is_variadic.append_value(is_variadic); + } + } + + fn finish(&mut self) -> RecordBatch { + RecordBatch::try_new( + self.schema.clone(), + vec![ + Arc::new(self.specific_catalog.finish()), + Arc::new(self.specific_schema.finish()), + Arc::new(self.specific_name.finish()), + Arc::new(self.ordinal_position.finish()), + Arc::new(self.parameter_mode.finish()), + Arc::new(self.parameter_name.finish()), + Arc::new(self.data_type.finish()), + Arc::new(self.parameter_default.finish()), + Arc::new(self.is_variadic.finish()), + ], + ) + .unwrap() + } +} + +impl PartitionStream for InformationSchemaParameters { + fn schema(&self) -> &SchemaRef { + &self.schema + } + + fn execute(&self, ctx: Arc) -> SendableRecordBatchStream { + let config = self.config.clone(); + let mut builder = self.builder(); + Box::pin(RecordBatchStreamAdapter::new( + self.schema.clone(), + futures::stream::once(async move { + config.make_parameters( + ctx.scalar_functions(), + ctx.aggregate_functions(), + ctx.window_functions(), + ctx.session_config().options(), + &mut builder, + )?; + Ok(builder.finish()) + }), + )) + } +} diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index dd5156cb53cc..4d51a61c8a52 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -39,6 +39,7 @@ SELECT * from information_schema.tables; ---- datafusion information_schema columns VIEW datafusion information_schema df_settings VIEW +datafusion information_schema parameters VIEW datafusion information_schema routines VIEW datafusion information_schema schemata VIEW datafusion information_schema tables VIEW @@ -84,6 +85,7 @@ SELECT * from information_schema.tables; ---- datafusion information_schema columns VIEW datafusion information_schema df_settings VIEW +datafusion information_schema parameters VIEW datafusion information_schema routines VIEW datafusion information_schema schemata VIEW datafusion information_schema tables VIEW @@ -99,6 +101,7 @@ SELECT * from information_schema.tables; ---- datafusion information_schema columns VIEW datafusion information_schema df_settings VIEW +datafusion information_schema parameters VIEW datafusion information_schema routines VIEW datafusion information_schema schemata VIEW datafusion information_schema 
tables VIEW @@ -111,6 +114,7 @@ SELECT * from information_schema.tables WHERE tables.table_schema='information_s ---- datafusion information_schema columns VIEW datafusion information_schema df_settings VIEW +datafusion information_schema parameters VIEW datafusion information_schema routines VIEW datafusion information_schema schemata VIEW datafusion information_schema tables VIEW @@ -121,6 +125,7 @@ SELECT * from information_schema.tables WHERE information_schema.tables.table_sc ---- datafusion information_schema columns VIEW datafusion information_schema df_settings VIEW +datafusion information_schema parameters VIEW datafusion information_schema routines VIEW datafusion information_schema schemata VIEW datafusion information_schema tables VIEW @@ -131,6 +136,7 @@ SELECT * from information_schema.tables WHERE datafusion.information_schema.tabl ---- datafusion information_schema columns VIEW datafusion information_schema df_settings VIEW +datafusion information_schema parameters VIEW datafusion information_schema routines VIEW datafusion information_schema schemata VIEW datafusion information_schema tables VIEW @@ -454,6 +460,7 @@ SHOW TABLES ---- datafusion information_schema columns VIEW datafusion information_schema df_settings VIEW +datafusion information_schema parameters VIEW datafusion information_schema routines VIEW datafusion information_schema schemata VIEW datafusion information_schema tables VIEW @@ -636,3 +643,53 @@ query B select is_deterministic from information_schema.routines where routine_name = 'now'; ---- false + +# test every function type are included in the result +query TTTITTTTB rowsort +select * from information_schema.parameters where specific_name = 'date_trunc' OR specific_name = 'string_agg' OR specific_name = 'rank'; +---- +datafusion public date_trunc 1 IN precision Utf8 NULL false +datafusion public date_trunc 1 IN precision Utf8View NULL false +datafusion public date_trunc 1 OUT NULL Timestamp(Microsecond, None) NULL false +datafusion public date_trunc 1 OUT NULL Timestamp(Microsecond, Some("+TZ")) NULL false +datafusion public date_trunc 1 OUT NULL Timestamp(Millisecond, None) NULL false +datafusion public date_trunc 1 OUT NULL Timestamp(Millisecond, Some("+TZ")) NULL false +datafusion public date_trunc 1 OUT NULL Timestamp(Nanosecond, None) NULL false +datafusion public date_trunc 1 OUT NULL Timestamp(Nanosecond, Some("+TZ")) NULL false +datafusion public date_trunc 1 OUT NULL Timestamp(Second, None) NULL false +datafusion public date_trunc 1 OUT NULL Timestamp(Second, Some("+TZ")) NULL false +datafusion public date_trunc 2 IN expression Timestamp(Microsecond, None) NULL false +datafusion public date_trunc 2 IN expression Timestamp(Microsecond, Some("+TZ")) NULL false +datafusion public date_trunc 2 IN expression Timestamp(Millisecond, None) NULL false +datafusion public date_trunc 2 IN expression Timestamp(Millisecond, Some("+TZ")) NULL false +datafusion public date_trunc 2 IN expression Timestamp(Nanosecond, None) NULL false +datafusion public date_trunc 2 IN expression Timestamp(Nanosecond, Some("+TZ")) NULL false +datafusion public date_trunc 2 IN expression Timestamp(Second, None) NULL false +datafusion public date_trunc 2 IN expression Timestamp(Second, Some("+TZ")) NULL false +datafusion public string_agg 1 IN expression LargeUtf8 NULL false +datafusion public string_agg 1 OUT NULL LargeUtf8 NULL false +datafusion public string_agg 2 IN delimiter LargeUtf8 NULL false +datafusion public string_agg 2 IN delimiter Null NULL false +datafusion 
public string_agg 2 IN delimiter Utf8 NULL false + +# test variable length arguments +query TTTB rowsort +select specific_name, data_type, parameter_mode, is_variadic from information_schema.parameters where specific_name = 'concat'; +---- +concat LargeUtf8 IN true +concat LargeUtf8 OUT false +concat Utf8 IN true +concat Utf8 OUT false +concat Utf8View IN true +concat Utf8View OUT false + +# test ceorcion signature +query TTIT rowsort +select specific_name, data_type, ordinal_position, parameter_mode from information_schema.parameters where specific_name = 'repeat'; +---- +repeat Int64 2 IN +repeat LargeUtf8 1 IN +repeat LargeUtf8 1 OUT +repeat Utf8 1 IN +repeat Utf8 1 OUT +repeat Utf8View 1 IN diff --git a/datafusion/sqllogictest/test_files/information_schema_multiple_catalogs.slt b/datafusion/sqllogictest/test_files/information_schema_multiple_catalogs.slt index 988a4275c6e3..0594aa7cfca8 100644 --- a/datafusion/sqllogictest/test_files/information_schema_multiple_catalogs.slt +++ b/datafusion/sqllogictest/test_files/information_schema_multiple_catalogs.slt @@ -35,6 +35,7 @@ SELECT * from information_schema.tables; ---- datafusion information_schema columns VIEW datafusion information_schema df_settings VIEW +datafusion information_schema parameters VIEW datafusion information_schema routines VIEW datafusion information_schema schemata VIEW datafusion information_schema tables VIEW @@ -81,12 +82,14 @@ SELECT * from information_schema.tables; ---- datafusion information_schema columns VIEW datafusion information_schema df_settings VIEW +datafusion information_schema parameters VIEW datafusion information_schema routines VIEW datafusion information_schema schemata VIEW datafusion information_schema tables VIEW datafusion information_schema views VIEW my_catalog information_schema columns VIEW my_catalog information_schema df_settings VIEW +my_catalog information_schema parameters VIEW my_catalog information_schema routines VIEW my_catalog information_schema schemata VIEW my_catalog information_schema tables VIEW @@ -95,6 +98,7 @@ my_catalog my_schema t1 BASE TABLE my_catalog my_schema t2 BASE TABLE my_other_catalog information_schema columns VIEW my_other_catalog information_schema df_settings VIEW +my_other_catalog information_schema parameters VIEW my_other_catalog information_schema routines VIEW my_other_catalog information_schema schemata VIEW my_other_catalog information_schema tables VIEW diff --git a/datafusion/sqllogictest/test_files/information_schema_table_types.slt b/datafusion/sqllogictest/test_files/information_schema_table_types.slt index 8a1a94c6a026..5650d537b06d 100644 --- a/datafusion/sqllogictest/test_files/information_schema_table_types.slt +++ b/datafusion/sqllogictest/test_files/information_schema_table_types.slt @@ -36,6 +36,7 @@ SELECT * from information_schema.tables; ---- datafusion information_schema columns VIEW datafusion information_schema df_settings VIEW +datafusion information_schema parameters VIEW datafusion information_schema routines VIEW datafusion information_schema schemata VIEW datafusion information_schema tables VIEW From cd69e373320a1682dfa7c32d14d5294add8b5fee Mon Sep 17 00:00:00 2001 From: Leonardo Yvens Date: Tue, 12 Nov 2024 04:42:56 +0000 Subject: [PATCH 06/17] support recursive CTEs logical plans in datafusion-proto (#13314) * support LogicaPlan::RecursiveQuery in datafusion-proto * fixed and failing test roundtrip_recursive_query * fix rebase artifact * add node for CteWorkTableScan in datafusion-proto * Use Arc::clone --------- 
Co-authored-by: jonahgao --- .../core/src/datasource/cte_worktable.rs | 12 +- datafusion/proto/proto/datafusion.proto | 14 + datafusion/proto/src/generated/pbjson.rs | 281 ++++++++++++++++++ datafusion/proto/src/generated/prost.rs | 26 +- datafusion/proto/src/logical_plan/mod.rs | 82 ++++- .../tests/cases/roundtrip_logical_plan.rs | 28 ++ 6 files changed, 434 insertions(+), 9 deletions(-) diff --git a/datafusion/core/src/datasource/cte_worktable.rs b/datafusion/core/src/datasource/cte_worktable.rs index 23f57b12ae08..9504721cecd6 100644 --- a/datafusion/core/src/datasource/cte_worktable.rs +++ b/datafusion/core/src/datasource/cte_worktable.rs @@ -39,8 +39,6 @@ use crate::datasource::{TableProvider, TableType}; #[derive(Debug)] pub struct CteWorkTable { /// The name of the CTE work table - // WIP, see https://github.com/apache/datafusion/issues/462 - #[allow(dead_code)] name: String, /// This schema must be shared across both the static and recursive terms of a recursive query table_schema: SchemaRef, @@ -56,6 +54,16 @@ impl CteWorkTable { table_schema, } } + + /// The user-provided name of the CTE + pub fn name(&self) -> &str { + &self.name + } + + /// The schema of the recursive term of the query + pub fn schema(&self) -> SchemaRef { + Arc::clone(&self.table_schema) + } } #[async_trait] diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index d6fa129edc3f..b9a1cff94d05 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -59,6 +59,8 @@ message LogicalPlanNode { DistinctOnNode distinct_on = 28; CopyToNode copy_to = 29; UnnestNode unnest = 30; + RecursiveQueryNode recursive_query = 31; + CteWorkTableScanNode cte_work_table_scan = 32; } } @@ -1249,3 +1251,15 @@ message PartitionStats { int64 num_bytes = 3; repeated datafusion_common.ColumnStats column_stats = 4; } + +message RecursiveQueryNode { + string name = 1; + LogicalPlanNode static_term = 2; + LogicalPlanNode recursive_term = 3; + bool is_distinct = 4; +} + +message CteWorkTableScanNode { + string name = 1; + datafusion_common.Schema schema = 2; +} diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 16f14d9ddf61..52ba1ea8aa79 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -4153,6 +4153,114 @@ impl<'de> serde::Deserialize<'de> for CsvSinkExecNode { deserializer.deserialize_struct("datafusion.CsvSinkExecNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for CteWorkTableScanNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.name.is_empty() { + len += 1; + } + if self.schema.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.CteWorkTableScanNode", len)?; + if !self.name.is_empty() { + struct_ser.serialize_field("name", &self.name)?; + } + if let Some(v) = self.schema.as_ref() { + struct_ser.serialize_field("schema", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for CteWorkTableScanNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "name", + "schema", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Name, + Schema, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn 
deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "name" => Ok(GeneratedField::Name), + "schema" => Ok(GeneratedField::Schema), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = CteWorkTableScanNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.CteWorkTableScanNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut name__ = None; + let mut schema__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Name => { + if name__.is_some() { + return Err(serde::de::Error::duplicate_field("name")); + } + name__ = Some(map_.next_value()?); + } + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); + } + schema__ = map_.next_value()?; + } + } + } + Ok(CteWorkTableScanNode { + name: name__.unwrap_or_default(), + schema: schema__, + }) + } + } + deserializer.deserialize_struct("datafusion.CteWorkTableScanNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for CubeNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -10602,6 +10710,12 @@ impl serde::Serialize for LogicalPlanNode { logical_plan_node::LogicalPlanType::Unnest(v) => { struct_ser.serialize_field("unnest", v)?; } + logical_plan_node::LogicalPlanType::RecursiveQuery(v) => { + struct_ser.serialize_field("recursiveQuery", v)?; + } + logical_plan_node::LogicalPlanType::CteWorkTableScan(v) => { + struct_ser.serialize_field("cteWorkTableScan", v)?; + } } } struct_ser.end() @@ -10656,6 +10770,10 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { "copy_to", "copyTo", "unnest", + "recursive_query", + "recursiveQuery", + "cte_work_table_scan", + "cteWorkTableScan", ]; #[allow(clippy::enum_variant_names)] @@ -10689,6 +10807,8 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { DistinctOn, CopyTo, Unnest, + RecursiveQuery, + CteWorkTableScan, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -10739,6 +10859,8 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { "distinctOn" | "distinct_on" => Ok(GeneratedField::DistinctOn), "copyTo" | "copy_to" => Ok(GeneratedField::CopyTo), "unnest" => Ok(GeneratedField::Unnest), + "recursiveQuery" | "recursive_query" => Ok(GeneratedField::RecursiveQuery), + "cteWorkTableScan" | "cte_work_table_scan" => Ok(GeneratedField::CteWorkTableScan), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -10962,6 +11084,20 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { return Err(serde::de::Error::duplicate_field("unnest")); } logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Unnest) +; + } + GeneratedField::RecursiveQuery => { + if logical_plan_type__.is_some() { 
+ return Err(serde::de::Error::duplicate_field("recursiveQuery")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::RecursiveQuery) +; + } + GeneratedField::CteWorkTableScan => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("cteWorkTableScan")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CteWorkTableScan) ; } } @@ -17486,6 +17622,151 @@ impl<'de> serde::Deserialize<'de> for RecursionUnnestOption { deserializer.deserialize_struct("datafusion.RecursionUnnestOption", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for RecursiveQueryNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.name.is_empty() { + len += 1; + } + if self.static_term.is_some() { + len += 1; + } + if self.recursive_term.is_some() { + len += 1; + } + if self.is_distinct { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.RecursiveQueryNode", len)?; + if !self.name.is_empty() { + struct_ser.serialize_field("name", &self.name)?; + } + if let Some(v) = self.static_term.as_ref() { + struct_ser.serialize_field("staticTerm", v)?; + } + if let Some(v) = self.recursive_term.as_ref() { + struct_ser.serialize_field("recursiveTerm", v)?; + } + if self.is_distinct { + struct_ser.serialize_field("isDistinct", &self.is_distinct)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for RecursiveQueryNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "name", + "static_term", + "staticTerm", + "recursive_term", + "recursiveTerm", + "is_distinct", + "isDistinct", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Name, + StaticTerm, + RecursiveTerm, + IsDistinct, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "name" => Ok(GeneratedField::Name), + "staticTerm" | "static_term" => Ok(GeneratedField::StaticTerm), + "recursiveTerm" | "recursive_term" => Ok(GeneratedField::RecursiveTerm), + "isDistinct" | "is_distinct" => Ok(GeneratedField::IsDistinct), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = RecursiveQueryNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.RecursiveQueryNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut name__ = None; + let mut static_term__ = None; + let mut recursive_term__ = None; + let mut is_distinct__ = None; + while let Some(k) = map_.next_key()? 
{ + match k { + GeneratedField::Name => { + if name__.is_some() { + return Err(serde::de::Error::duplicate_field("name")); + } + name__ = Some(map_.next_value()?); + } + GeneratedField::StaticTerm => { + if static_term__.is_some() { + return Err(serde::de::Error::duplicate_field("staticTerm")); + } + static_term__ = map_.next_value()?; + } + GeneratedField::RecursiveTerm => { + if recursive_term__.is_some() { + return Err(serde::de::Error::duplicate_field("recursiveTerm")); + } + recursive_term__ = map_.next_value()?; + } + GeneratedField::IsDistinct => { + if is_distinct__.is_some() { + return Err(serde::de::Error::duplicate_field("isDistinct")); + } + is_distinct__ = Some(map_.next_value()?); + } + } + } + Ok(RecursiveQueryNode { + name: name__.unwrap_or_default(), + static_term: static_term__, + recursive_term: recursive_term__, + is_distinct: is_distinct__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.RecursiveQueryNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for RepartitionExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 59a90eb31ade..c7f5606049c0 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -5,7 +5,7 @@ pub struct LogicalPlanNode { #[prost( oneof = "logical_plan_node::LogicalPlanType", - tags = "1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30" + tags = "1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32" )] pub logical_plan_type: ::core::option::Option, } @@ -71,6 +71,10 @@ pub mod logical_plan_node { CopyTo(::prost::alloc::boxed::Box), #[prost(message, tag = "30")] Unnest(::prost::alloc::boxed::Box), + #[prost(message, tag = "31")] + RecursiveQuery(::prost::alloc::boxed::Box), + #[prost(message, tag = "32")] + CteWorkTableScan(super::CteWorkTableScanNode), } } #[derive(Clone, PartialEq, ::prost::Message)] @@ -1811,6 +1815,26 @@ pub struct PartitionStats { #[prost(message, repeated, tag = "4")] pub column_stats: ::prost::alloc::vec::Vec, } +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct RecursiveQueryNode { + #[prost(string, tag = "1")] + pub name: ::prost::alloc::string::String, + #[prost(message, optional, boxed, tag = "2")] + pub static_term: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, boxed, tag = "3")] + pub recursive_term: ::core::option::Option< + ::prost::alloc::boxed::Box, + >, + #[prost(bool, tag = "4")] + pub is_distinct: bool, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CteWorkTableScanNode { + #[prost(string, tag = "1")] + pub name: ::prost::alloc::string::String, + #[prost(message, optional, tag = "2")] + pub schema: ::core::option::Option, +} #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum BuiltInWindowFunction { diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index a55fecec98f6..50636048ebc9 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -21,8 +21,8 @@ use std::sync::Arc; use crate::protobuf::logical_plan_node::LogicalPlanType::CustomScan; use crate::protobuf::{ - ColumnUnnestListItem, ColumnUnnestListRecursion, CustomTableScanNode, - SortExprNodeCollection, + 
ColumnUnnestListItem, ColumnUnnestListRecursion, CteWorkTableScanNode, + CustomTableScanNode, SortExprNodeCollection, }; use crate::{ convert_required, into_required, @@ -34,6 +34,7 @@ use crate::{ use crate::protobuf::{proto_error, ToProtoError}; use arrow::datatypes::{DataType, Schema, SchemaRef}; +use datafusion::datasource::cte_worktable::CteWorkTable; #[cfg(feature = "parquet")] use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::file_format::{ @@ -68,7 +69,9 @@ use datafusion_expr::{ DistinctOn, DropView, Expr, LogicalPlan, LogicalPlanBuilder, ScalarUDF, SortExpr, Statement, WindowUDF, }; -use datafusion_expr::{AggregateUDF, ColumnUnnestList, FetchType, SkipType, Unnest}; +use datafusion_expr::{ + AggregateUDF, ColumnUnnestList, FetchType, RecursiveQuery, SkipType, Unnest, +}; use self::to_proto::{serialize_expr, serialize_exprs}; use crate::logical_plan::to_proto::serialize_sorts; @@ -905,6 +908,41 @@ impl AsLogicalPlan for LogicalPlanNode { options: into_required!(unnest.options)?, })) } + LogicalPlanType::RecursiveQuery(recursive_query_node) => { + let static_term = recursive_query_node + .static_term + .as_ref() + .ok_or_else(|| DataFusionError::Internal(String::from( + "Protobuf deserialization error, RecursiveQueryNode was missing required field static_term.", + )))? + .try_into_logical_plan(ctx, extension_codec)?; + + let recursive_term = recursive_query_node + .recursive_term + .as_ref() + .ok_or_else(|| DataFusionError::Internal(String::from( + "Protobuf deserialization error, RecursiveQueryNode was missing required field recursive_term.", + )))? + .try_into_logical_plan(ctx, extension_codec)?; + + Ok(LogicalPlan::RecursiveQuery(RecursiveQuery { + name: recursive_query_node.name.clone(), + static_term: Arc::new(static_term), + recursive_term: Arc::new(recursive_term), + is_distinct: recursive_query_node.is_distinct, + })) + } + LogicalPlanType::CteWorkTableScan(cte_work_table_scan_node) => { + let CteWorkTableScanNode { name, schema } = cte_work_table_scan_node; + let schema = convert_required!(*schema)?; + let cte_work_table = CteWorkTable::new(name.as_str(), Arc::new(schema)); + LogicalPlanBuilder::scan( + name.as_str(), + provider_as_source(Arc::new(cte_work_table)), + None, + )? 
+ .build() + } } } @@ -1061,6 +1099,20 @@ impl AsLogicalPlan for LogicalPlanNode { }, ))), }) + } else if let Some(cte_work_table) = source.downcast_ref::() + { + let name = cte_work_table.name().to_string(); + let schema = cte_work_table.schema(); + let schema: protobuf::Schema = schema.as_ref().try_into()?; + + Ok(LogicalPlanNode { + logical_plan_type: Some(LogicalPlanType::CteWorkTableScan( + protobuf::CteWorkTableScanNode { + name, + schema: Some(schema), + }, + )), + }) } else { let mut bytes = vec![]; extension_codec @@ -1630,9 +1682,27 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlan::DescribeTable(_) => Err(proto_error( "LogicalPlan serde is not yet implemented for DescribeTable", )), - LogicalPlan::RecursiveQuery(_) => Err(proto_error( - "LogicalPlan serde is not yet implemented for RecursiveQuery", - )), + LogicalPlan::RecursiveQuery(recursive) => { + let static_term = LogicalPlanNode::try_from_logical_plan( + recursive.static_term.as_ref(), + extension_codec, + )?; + let recursive_term = LogicalPlanNode::try_from_logical_plan( + recursive.recursive_term.as_ref(), + extension_codec, + )?; + + Ok(LogicalPlanNode { + logical_plan_type: Some(LogicalPlanType::RecursiveQuery(Box::new( + protobuf::RecursiveQueryNode { + name: recursive.name.clone(), + static_term: Some(Box::new(static_term)), + recursive_term: Some(Box::new(recursive_term)), + is_distinct: recursive.is_distinct, + }, + ))), + }) + } } } } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index ecfbaee23537..8445cdc761ed 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -23,6 +23,7 @@ use arrow::datatypes::{ IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode, DECIMAL256_MAX_PRECISION, }; +use arrow::util::pretty::pretty_format_batches; use datafusion::datasource::file_format::json::JsonFormatFactory; use datafusion_common::parsers::CompressionTypeVariant; use prost::Message; @@ -2523,3 +2524,30 @@ fn roundtrip_window() { roundtrip_expr_test(test_expr6, ctx.clone()); roundtrip_expr_test(text_expr7, ctx); } + +#[tokio::test] +async fn roundtrip_recursive_query() { + let query = "WITH RECURSIVE cte AS ( + SELECT 1 as n + UNION ALL + SELECT n + 1 FROM cte WHERE n < 5 + ) + SELECT * FROM cte;"; + + let ctx = SessionContext::new(); + let dataframe = ctx.sql(query).await.unwrap(); + let plan = dataframe.logical_plan().clone(); + let output = dataframe.collect().await.unwrap(); + let bytes = logical_plan_to_bytes(&plan).unwrap(); + + let ctx = SessionContext::new(); + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx).unwrap(); + assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + let dataframe = ctx.execute_logical_plan(logical_round_trip).await.unwrap(); + let output_round_trip = dataframe.collect().await.unwrap(); + + assert_eq!( + format!("{}", pretty_format_batches(&output).unwrap()), + format!("{}", pretty_format_batches(&output_round_trip).unwrap()) + ); +} From 5cc12235a73a78922dfd8c9d39304d249273ee44 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 12 Nov 2024 10:37:43 +0100 Subject: [PATCH 07/17] Update substrait requirement from 0.46 to 0.47 (#13374) Updates the requirements on [substrait](https://github.com/substrait-io/substrait-rs) to permit the latest version. 
- [Release notes](https://github.com/substrait-io/substrait-rs/releases) - [Changelog](https://github.com/substrait-io/substrait-rs/blob/main/CHANGELOG.md) - [Commits](https://github.com/substrait-io/substrait-rs/compare/v0.46.0...v0.47.1) --- updated-dependencies: - dependency-name: substrait dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- datafusion/substrait/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index 9432c798e6d6..192fe26d6cef 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -41,7 +41,7 @@ object_store = { workspace = true } pbjson-types = "0.7" # TODO use workspace version prost = "0.13" -substrait = { version = "0.46", features = ["serde"] } +substrait = { version = "0.47", features = ["serde"] } url = { workspace = true } [dev-dependencies] From 1bbe13f013de379328eee4b124ee02c55d08b823 Mon Sep 17 00:00:00 2001 From: Namgung Chan <33323415+getChan@users.noreply.github.com> Date: Tue, 12 Nov 2024 22:52:37 +0900 Subject: [PATCH 08/17] Enable `clone_on_ref_ptr` clippy lint on core crate (#13338) * Enable clone_on_ref_ptr clippy lint on physical-expr-common crate * cargo fmt * remove explicit type * information_schema * listing_schema * memory * avro_to_arrow.reader * cte_worktable * clone_on_ref_ptr * type infer * fmt * except test * after rebase main brach --- .../src/catalog_common/information_schema.rs | 44 ++++++------- .../core/src/catalog_common/listing_schema.rs | 5 +- datafusion/core/src/catalog_common/memory.rs | 6 +- .../src/datasource/avro_to_arrow/reader.rs | 4 +- .../core/src/datasource/cte_worktable.rs | 4 +- .../src/datasource/default_table_source.rs | 2 +- datafusion/core/src/datasource/empty.rs | 2 +- .../core/src/datasource/file_format/arrow.rs | 6 +- .../core/src/datasource/file_format/csv.rs | 2 +- .../core/src/datasource/file_format/json.rs | 2 +- .../core/src/datasource/file_format/mod.rs | 10 +-- .../src/datasource/file_format/parquet.rs | 40 +++++++----- .../src/datasource/file_format/write/demux.rs | 4 +- .../file_format/write/orchestration.rs | 5 +- .../core/src/datasource/listing/table.rs | 4 +- .../src/datasource/listing_table_factory.rs | 2 +- datafusion/core/src/datasource/memory.rs | 12 ++-- .../datasource/physical_plan/arrow_file.rs | 6 +- .../core/src/datasource/physical_plan/avro.rs | 8 +-- .../core/src/datasource/physical_plan/csv.rs | 10 +-- .../datasource/physical_plan/file_stream.rs | 5 +- .../core/src/datasource/physical_plan/json.rs | 10 +-- .../datasource/physical_plan/parquet/mod.rs | 6 +- .../physical_plan/parquet/opener.rs | 4 +- .../physical_plan/parquet/page_filter.rs | 18 +++--- .../physical_plan/parquet/row_filter.rs | 2 +- .../physical_plan/parquet/writer.rs | 6 +- .../core/src/datasource/schema_adapter.rs | 6 +- datafusion/core/src/datasource/stream.rs | 18 +++--- datafusion/core/src/datasource/streaming.rs | 4 +- datafusion/core/src/execution/context/mod.rs | 4 +- .../core/src/execution/session_state.rs | 29 +++++---- .../src/execution/session_state_defaults.rs | 9 ++- datafusion/core/src/lib.rs | 2 + .../enforce_distribution.rs | 63 ++++++++++--------- .../src/physical_optimizer/enforce_sorting.rs | 29 +++++---- .../src/physical_optimizer/join_selection.rs | 8 +-- .../physical_optimizer/projection_pushdown.rs | 28 ++++----- .../core/src/physical_optimizer/pruning.rs | 34 +++++----- 
.../replace_with_order_preserving_variants.rs | 8 +-- .../src/physical_optimizer/sanity_checker.rs | 2 +- .../src/physical_optimizer/sort_pushdown.rs | 4 +- .../core/src/physical_optimizer/utils.rs | 2 +- datafusion/core/src/physical_planner.rs | 12 ++-- datafusion/core/src/test_util/mod.rs | 4 +- datafusion/core/src/test_util/parquet.rs | 13 ++-- 46 files changed, 270 insertions(+), 238 deletions(-) diff --git a/datafusion/core/src/catalog_common/information_schema.rs b/datafusion/core/src/catalog_common/information_schema.rs index 72f842d3675e..1d4a3c15f7ca 100644 --- a/datafusion/core/src/catalog_common/information_schema.rs +++ b/datafusion/core/src/catalog_common/information_schema.rs @@ -493,7 +493,7 @@ impl SchemaProvider for InformationSchemaProvider { }; Ok(Some(Arc::new( - StreamingTable::try_new(table.schema().clone(), vec![table]).unwrap(), + StreamingTable::try_new(Arc::clone(table.schema()), vec![table]).unwrap(), ))) } @@ -526,7 +526,7 @@ impl InformationSchemaTables { schema_names: StringBuilder::new(), table_names: StringBuilder::new(), table_types: StringBuilder::new(), - schema: self.schema.clone(), + schema: Arc::clone(&self.schema), } } } @@ -540,7 +540,7 @@ impl PartitionStream for InformationSchemaTables { let mut builder = self.builder(); let config = self.config.clone(); Box::pin(RecordBatchStreamAdapter::new( - self.schema.clone(), + Arc::clone(&self.schema), // TODO: Stream this futures::stream::once(async move { config.make_tables(&mut builder).await?; @@ -582,7 +582,7 @@ impl InformationSchemaTablesBuilder { fn finish(&mut self) -> RecordBatch { RecordBatch::try_new( - self.schema.clone(), + Arc::clone(&self.schema), vec![ Arc::new(self.catalog_names.finish()), Arc::new(self.schema_names.finish()), @@ -618,7 +618,7 @@ impl InformationSchemaViews { schema_names: StringBuilder::new(), table_names: StringBuilder::new(), definitions: StringBuilder::new(), - schema: self.schema.clone(), + schema: Arc::clone(&self.schema), } } } @@ -632,7 +632,7 @@ impl PartitionStream for InformationSchemaViews { let mut builder = self.builder(); let config = self.config.clone(); Box::pin(RecordBatchStreamAdapter::new( - self.schema.clone(), + Arc::clone(&self.schema), // TODO: Stream this futures::stream::once(async move { config.make_views(&mut builder).await?; @@ -670,7 +670,7 @@ impl InformationSchemaViewBuilder { fn finish(&mut self) -> RecordBatch { RecordBatch::try_new( - self.schema.clone(), + Arc::clone(&self.schema), vec![ Arc::new(self.catalog_names.finish()), Arc::new(self.schema_names.finish()), @@ -733,7 +733,7 @@ impl InformationSchemaColumns { numeric_scales: UInt64Builder::with_capacity(default_capacity), datetime_precisions: UInt64Builder::with_capacity(default_capacity), interval_types: StringBuilder::new(), - schema: self.schema.clone(), + schema: Arc::clone(&self.schema), } } } @@ -747,7 +747,7 @@ impl PartitionStream for InformationSchemaColumns { let mut builder = self.builder(); let config = self.config.clone(); Box::pin(RecordBatchStreamAdapter::new( - self.schema.clone(), + Arc::clone(&self.schema), // TODO: Stream this futures::stream::once(async move { config.make_columns(&mut builder).await?; @@ -876,7 +876,7 @@ impl InformationSchemaColumnsBuilder { fn finish(&mut self) -> RecordBatch { RecordBatch::try_new( - self.schema.clone(), + Arc::clone(&self.schema), vec![ Arc::new(self.catalog_names.finish()), Arc::new(self.schema_names.finish()), @@ -921,7 +921,7 @@ impl InformationSchemata { fn builder(&self) -> InformationSchemataBuilder { 
InformationSchemataBuilder { - schema: self.schema.clone(), + schema: Arc::clone(&self.schema), catalog_name: StringBuilder::new(), schema_name: StringBuilder::new(), schema_owner: StringBuilder::new(), @@ -967,7 +967,7 @@ impl InformationSchemataBuilder { fn finish(&mut self) -> RecordBatch { RecordBatch::try_new( - self.schema.clone(), + Arc::clone(&self.schema), vec![ Arc::new(self.catalog_name.finish()), Arc::new(self.schema_name.finish()), @@ -991,7 +991,7 @@ impl PartitionStream for InformationSchemata { let mut builder = self.builder(); let config = self.config.clone(); Box::pin(RecordBatchStreamAdapter::new( - self.schema.clone(), + Arc::clone(&self.schema), // TODO: Stream this futures::stream::once(async move { config.make_schemata(&mut builder).await; @@ -1023,7 +1023,7 @@ impl InformationSchemaDfSettings { names: StringBuilder::new(), values: StringBuilder::new(), descriptions: StringBuilder::new(), - schema: self.schema.clone(), + schema: Arc::clone(&self.schema), } } } @@ -1037,7 +1037,7 @@ impl PartitionStream for InformationSchemaDfSettings { let config = self.config.clone(); let mut builder = self.builder(); Box::pin(RecordBatchStreamAdapter::new( - self.schema.clone(), + Arc::clone(&self.schema), // TODO: Stream this futures::stream::once(async move { // create a mem table with the names of tables @@ -1064,7 +1064,7 @@ impl InformationSchemaDfSettingsBuilder { fn finish(&mut self) -> RecordBatch { RecordBatch::try_new( - self.schema.clone(), + Arc::clone(&self.schema), vec![ Arc::new(self.names.finish()), Arc::new(self.values.finish()), @@ -1102,7 +1102,7 @@ impl InformationSchemaRoutines { fn builder(&self) -> InformationSchemaRoutinesBuilder { InformationSchemaRoutinesBuilder { - schema: self.schema.clone(), + schema: Arc::clone(&self.schema), specific_catalog: StringBuilder::new(), specific_schema: StringBuilder::new(), specific_name: StringBuilder::new(), @@ -1161,7 +1161,7 @@ impl InformationSchemaRoutinesBuilder { fn finish(&mut self) -> RecordBatch { RecordBatch::try_new( - self.schema.clone(), + Arc::clone(&self.schema), vec![ Arc::new(self.specific_catalog.finish()), Arc::new(self.specific_schema.finish()), @@ -1189,7 +1189,7 @@ impl PartitionStream for InformationSchemaRoutines { let config = self.config.clone(); let mut builder = self.builder(); Box::pin(RecordBatchStreamAdapter::new( - self.schema.clone(), + Arc::clone(&self.schema), futures::stream::once(async move { config.make_routines( ctx.scalar_functions(), @@ -1229,7 +1229,7 @@ impl InformationSchemaParameters { fn builder(&self) -> InformationSchemaParametersBuilder { InformationSchemaParametersBuilder { - schema: self.schema.clone(), + schema: Arc::clone(&self.schema), specific_catalog: StringBuilder::new(), specific_schema: StringBuilder::new(), specific_name: StringBuilder::new(), @@ -1295,7 +1295,7 @@ impl InformationSchemaParametersBuilder { fn finish(&mut self) -> RecordBatch { RecordBatch::try_new( - self.schema.clone(), + Arc::clone(&self.schema), vec![ Arc::new(self.specific_catalog.finish()), Arc::new(self.specific_schema.finish()), @@ -1321,7 +1321,7 @@ impl PartitionStream for InformationSchemaParameters { let config = self.config.clone(); let mut builder = self.builder(); Box::pin(RecordBatchStreamAdapter::new( - self.schema.clone(), + Arc::clone(&self.schema), futures::stream::once(async move { config.make_parameters( ctx.scalar_functions(), diff --git a/datafusion/core/src/catalog_common/listing_schema.rs b/datafusion/core/src/catalog_common/listing_schema.rs index 
67952770f41c..dc55a07ef82d 100644 --- a/datafusion/core/src/catalog_common/listing_schema.rs +++ b/datafusion/core/src/catalog_common/listing_schema.rs @@ -148,7 +148,8 @@ impl ListingSchemaProvider { }, ) .await?; - let _ = self.register_table(table_name.to_string(), provider.clone())?; + let _ = + self.register_table(table_name.to_string(), Arc::clone(&provider))?; } } Ok(()) @@ -190,7 +191,7 @@ impl SchemaProvider for ListingSchemaProvider { self.tables .lock() .expect("Can't lock tables") - .insert(name, table.clone()); + .insert(name, Arc::clone(&table)); Ok(Some(table)) } diff --git a/datafusion/core/src/catalog_common/memory.rs b/datafusion/core/src/catalog_common/memory.rs index f25146616891..6cdefc31f18c 100644 --- a/datafusion/core/src/catalog_common/memory.rs +++ b/datafusion/core/src/catalog_common/memory.rs @@ -67,7 +67,7 @@ impl CatalogProviderList for MemoryCatalogProviderList { } fn catalog(&self, name: &str) -> Option> { - self.catalogs.get(name).map(|c| c.value().clone()) + self.catalogs.get(name).map(|c| Arc::clone(c.value())) } } @@ -102,7 +102,7 @@ impl CatalogProvider for MemoryCatalogProvider { } fn schema(&self, name: &str) -> Option> { - self.schemas.get(name).map(|s| s.value().clone()) + self.schemas.get(name).map(|s| Arc::clone(s.value())) } fn register_schema( @@ -175,7 +175,7 @@ impl SchemaProvider for MemorySchemaProvider { &self, name: &str, ) -> datafusion_common::Result>, DataFusionError> { - Ok(self.tables.get(name).map(|table| table.value().clone())) + Ok(self.tables.get(name).map(|table| Arc::clone(table.value()))) } fn register_table( diff --git a/datafusion/core/src/datasource/avro_to_arrow/reader.rs b/datafusion/core/src/datasource/avro_to_arrow/reader.rs index 5dc53c5c86c8..e6310cec7475 100644 --- a/datafusion/core/src/datasource/avro_to_arrow/reader.rs +++ b/datafusion/core/src/datasource/avro_to_arrow/reader.rs @@ -142,7 +142,7 @@ impl<'a, R: Read> Reader<'a, R> { Ok(Self { array_reader: AvroArrowArrayReader::try_new( reader, - schema.clone(), + Arc::clone(&schema), projection, )?, schema, @@ -153,7 +153,7 @@ impl<'a, R: Read> Reader<'a, R> { /// Returns the schema of the reader, useful for getting the schema without reading /// record batches pub fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } diff --git a/datafusion/core/src/datasource/cte_worktable.rs b/datafusion/core/src/datasource/cte_worktable.rs index 9504721cecd6..b63755f644a8 100644 --- a/datafusion/core/src/datasource/cte_worktable.rs +++ b/datafusion/core/src/datasource/cte_worktable.rs @@ -77,7 +77,7 @@ impl TableProvider for CteWorkTable { } fn schema(&self) -> SchemaRef { - self.table_schema.clone() + Arc::clone(&self.table_schema) } fn table_type(&self) -> TableType { @@ -94,7 +94,7 @@ impl TableProvider for CteWorkTable { // TODO: pushdown filters and limits Ok(Arc::new(WorkTableExec::new( self.name.clone(), - self.table_schema.clone(), + Arc::clone(&self.table_schema), ))) } diff --git a/datafusion/core/src/datasource/default_table_source.rs b/datafusion/core/src/datasource/default_table_source.rs index b4a5a76fc9ff..c37c3b97f4fe 100644 --- a/datafusion/core/src/datasource/default_table_source.rs +++ b/datafusion/core/src/datasource/default_table_source.rs @@ -96,7 +96,7 @@ pub fn source_as_provider( .as_any() .downcast_ref::() { - Some(source) => Ok(source.table_provider.clone()), + Some(source) => Ok(Arc::clone(&source.table_provider)), _ => internal_err!("TableSource was not DefaultTableSource"), } } diff --git 
a/datafusion/core/src/datasource/empty.rs b/datafusion/core/src/datasource/empty.rs index bc5b82bd8c5b..abda7fa9ec4b 100644 --- a/datafusion/core/src/datasource/empty.rs +++ b/datafusion/core/src/datasource/empty.rs @@ -61,7 +61,7 @@ impl TableProvider for EmptyTable { } fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } fn table_type(&self) -> TableType { diff --git a/datafusion/core/src/datasource/file_format/arrow.rs b/datafusion/core/src/datasource/file_format/arrow.rs index c10ebbd6c9ea..fd0935c6e031 100644 --- a/datafusion/core/src/datasource/file_format/arrow.rs +++ b/datafusion/core/src/datasource/file_format/arrow.rs @@ -186,7 +186,7 @@ impl FileFormat for ArrowFormat { return not_impl_err!("Overwrites are not implemented yet for Arrow format"); } - let sink_schema = conf.output_schema().clone(); + let sink_schema = Arc::clone(conf.output_schema()); let sink = Arc::new(ArrowFileSink::new(conf)); Ok(Arc::new(DataSinkExec::new( @@ -229,7 +229,7 @@ impl ArrowFileSink { .collect::>(), )) } else { - self.config.output_schema().clone() + Arc::clone(self.config.output_schema()) } } } @@ -302,7 +302,7 @@ impl DataSink for ArrowFileSink { let mut object_store_writer = create_writer( FileCompressionType::UNCOMPRESSED, &path, - object_store.clone(), + Arc::clone(&object_store), ) .await?; file_write_tasks.spawn(async move { diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index 2b5570dc334c..d59e2bf71d64 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -417,7 +417,7 @@ impl FileFormat for CsvFormat { let writer_options = CsvWriterOptions::try_from(&options)?; - let sink_schema = conf.output_schema().clone(); + let sink_schema = Arc::clone(conf.output_schema()); let sink = Arc::new(CsvSink::new(conf, writer_options)); Ok(Arc::new(DataSinkExec::new( diff --git a/datafusion/core/src/datasource/file_format/json.rs b/datafusion/core/src/datasource/file_format/json.rs index 6a7bfd2040f0..4f51dd5ae1f5 100644 --- a/datafusion/core/src/datasource/file_format/json.rs +++ b/datafusion/core/src/datasource/file_format/json.rs @@ -264,7 +264,7 @@ impl FileFormat for JsonFormat { let writer_options = JsonWriterOptions::try_from(&self.options)?; - let sink_schema = conf.output_schema().clone(); + let sink_schema = Arc::clone(conf.output_schema()); let sink = Arc::new(JsonSink::new(conf, writer_options)); Ok(Arc::new(DataSinkExec::new( diff --git a/datafusion/core/src/datasource/file_format/mod.rs b/datafusion/core/src/datasource/file_format/mod.rs index b0e1df51839d..5c9eb7f20ae2 100644 --- a/datafusion/core/src/datasource/file_format/mod.rs +++ b/datafusion/core/src/datasource/file_format/mod.rs @@ -230,7 +230,7 @@ pub fn file_type_to_format( .as_any() .downcast_ref::() { - Some(source) => Ok(source.file_format_factory.clone()), + Some(source) => Ok(Arc::clone(&source.file_format_factory)), _ => internal_err!("FileType was not DefaultFileType"), } } @@ -255,7 +255,7 @@ pub fn transform_schema_to_view(schema: &Schema) -> Schema { DataType::Binary | DataType::LargeBinary => { field_with_new_type(field, DataType::BinaryView) } - _ => field.clone(), + _ => Arc::clone(field), }) .collect(); Schema::new_with_metadata(transformed_fields, schema.metadata.clone()) @@ -297,7 +297,7 @@ pub(crate) fn coerce_file_schema_to_view_type( Some(DataType::BinaryView), DataType::Binary | DataType::LargeBinary, ) => field_with_new_type(field, 
DataType::BinaryView), - _ => field.clone(), + _ => Arc::clone(field), }, ) .collect(); @@ -317,7 +317,7 @@ pub fn transform_binary_to_string(schema: &Schema) -> Schema { DataType::Binary => field_with_new_type(field, DataType::Utf8), DataType::LargeBinary => field_with_new_type(field, DataType::LargeUtf8), DataType::BinaryView => field_with_new_type(field, DataType::Utf8View), - _ => field.clone(), + _ => Arc::clone(field), }) .collect(); Schema::new_with_metadata(transformed_fields, schema.metadata.clone()) @@ -365,7 +365,7 @@ pub(crate) fn coerce_file_schema_to_string_type( transform = true; field_with_new_type(field, DataType::Utf8View) } - _ => field.clone(), + _ => Arc::clone(field), }, ) .collect(); diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index e27a13b6e735..c1314bdb8641 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -434,7 +434,7 @@ impl FileFormat for ParquetFormat { return not_impl_err!("Overwrites are not implemented yet for Parquet"); } - let sink_schema = conf.output_schema().clone(); + let sink_schema = Arc::clone(conf.output_schema()); let sink = Arc::new(ParquetSink::new(conf, self.options.clone())); Ok(Arc::new(DataSinkExec::new( @@ -748,7 +748,7 @@ impl ParquetSink { schema.metadata().clone(), )) } else { - self.config.output_schema().clone() + Arc::clone(self.config.output_schema()) } } @@ -833,7 +833,7 @@ impl DataSink for ParquetSink { let mut writer = self .create_async_arrow_writer( &path, - object_store.clone(), + Arc::clone(&object_store), parquet_props.writer_options().clone(), ) .await?; @@ -857,7 +857,7 @@ impl DataSink for ParquetSink { // manage compressed blocks themselves. FileCompressionType::UNCOMPRESSED, &path, - object_store.clone(), + Arc::clone(&object_store), ) .await?; let schema = self.get_writer_schema(); @@ -1044,8 +1044,8 @@ fn spawn_parquet_parallel_serialization_task( let max_row_group_rows = writer_props.max_row_group_size(); let (mut column_writer_handles, mut col_array_channels) = spawn_column_parallel_row_group_writer( - schema.clone(), - writer_props.clone(), + Arc::clone(&schema), + Arc::clone(&writer_props), max_buffer_rb, &pool, )?; @@ -1057,15 +1057,23 @@ fn spawn_parquet_parallel_serialization_task( // function. 
loop { if current_rg_rows + rb.num_rows() < max_row_group_rows { - send_arrays_to_col_writers(&col_array_channels, &rb, schema.clone()) - .await?; + send_arrays_to_col_writers( + &col_array_channels, + &rb, + Arc::clone(&schema), + ) + .await?; current_rg_rows += rb.num_rows(); break; } else { let rows_left = max_row_group_rows - current_rg_rows; let a = rb.slice(0, rows_left); - send_arrays_to_col_writers(&col_array_channels, &a, schema.clone()) - .await?; + send_arrays_to_col_writers( + &col_array_channels, + &a, + Arc::clone(&schema), + ) + .await?; // Signal the parallel column writers that the RowGroup is done, join and finalize RowGroup // on a separate task, so that we can immediately start on the next RG before waiting @@ -1088,8 +1096,8 @@ fn spawn_parquet_parallel_serialization_task( (column_writer_handles, col_array_channels) = spawn_column_parallel_row_group_writer( - schema.clone(), - writer_props.clone(), + Arc::clone(&schema), + Arc::clone(&writer_props), max_buffer_rb, &pool, )?; @@ -1192,15 +1200,15 @@ async fn output_single_parquet_file_parallelized( let launch_serialization_task = spawn_parquet_parallel_serialization_task( data, serialize_tx, - output_schema.clone(), - arc_props.clone(), + Arc::clone(&output_schema), + Arc::clone(&arc_props), parallel_options, Arc::clone(&pool), ); let file_metadata = concatenate_parallel_row_groups( serialize_rx, - output_schema.clone(), - arc_props.clone(), + Arc::clone(&output_schema), + Arc::clone(&arc_props), object_store_writer, pool, ) diff --git a/datafusion/core/src/datasource/file_format/write/demux.rs b/datafusion/core/src/datasource/file_format/write/demux.rs index 56ded495c8a8..71cf747c328d 100644 --- a/datafusion/core/src/datasource/file_format/write/demux.rs +++ b/datafusion/core/src/datasource/file_format/write/demux.rs @@ -100,7 +100,7 @@ pub(crate) fn start_demuxer_task( keep_partition_by_columns: bool, ) -> (SpawnedTask>, DemuxedStreamReceiver) { let (tx, rx) = mpsc::unbounded_channel(); - let context = context.clone(); + let context = Arc::clone(context); let single_file_output = !base_output_path.is_collection() && base_output_path.file_extension().is_some(); let task = match partition_by { @@ -478,7 +478,7 @@ fn remove_partition_by_columns( .zip(parted_batch.schema().fields()) .filter_map(|(a, f)| { if !partition_names.contains(&f.name()) { - Some((a.clone(), (**f).clone())) + Some((Arc::clone(a), (**f).clone())) } else { None } diff --git a/datafusion/core/src/datasource/file_format/write/orchestration.rs b/datafusion/core/src/datasource/file_format/write/orchestration.rs index 6f27e6f3889f..457481b5ad76 100644 --- a/datafusion/core/src/datasource/file_format/write/orchestration.rs +++ b/datafusion/core/src/datasource/file_format/write/orchestration.rs @@ -94,7 +94,7 @@ pub(crate) async fn serialize_rb_stream_to_object_store( // subsequent batches, so we track that here. 
let mut initial = true; while let Some(batch) = data_rx.recv().await { - let serializer_clone = serializer.clone(); + let serializer_clone = Arc::clone(&serializer); let task = SpawnedTask::spawn(async move { let num_rows = batch.num_rows(); let bytes = serializer_clone.serialize(batch, initial)?; @@ -279,7 +279,8 @@ pub(crate) async fn stateless_multipart_put( }); while let Some((location, rb_stream)) = file_stream_rx.recv().await { let serializer = get_serializer(); - let writer = create_writer(compression, &location, object_store.clone()).await?; + let writer = + create_writer(compression, &location, Arc::clone(&object_store)).await?; tx_file_bundle .send((rb_stream, serializer, writer)) diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index b937a28e9332..ffe49dd2ba11 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -1137,14 +1137,14 @@ impl ListingTable { .infer_stats( ctx, store, - self.file_schema.clone(), + Arc::clone(&self.file_schema), &part_file.object_meta, ) .await?; let statistics = Arc::new(statistics); self.collected_statistics.put_with_extra( &part_file.object_meta.location, - statistics.clone(), + Arc::clone(&statistics), &part_file.object_meta, ); Ok(statistics) diff --git a/datafusion/core/src/datasource/listing_table_factory.rs b/datafusion/core/src/datasource/listing_table_factory.rs index 1f6a19ceb55c..636d1623c5e9 100644 --- a/datafusion/core/src/datasource/listing_table_factory.rs +++ b/datafusion/core/src/datasource/listing_table_factory.rs @@ -127,7 +127,7 @@ impl TableProviderFactory for ListingTableFactory { // See: https://github.com/apache/datafusion/issues/7317 None => { let schema = options.infer_schema(session_state, &table_path).await?; - let df_schema = schema.clone().to_dfschema()?; + let df_schema = Arc::clone(&schema).to_dfschema()?; let column_refs: HashSet<_> = cmd .order_exprs .iter() diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs index 3c2d1b0205d6..c1e0bea0b3ff 100644 --- a/datafusion/core/src/datasource/memory.rs +++ b/datafusion/core/src/datasource/memory.rs @@ -139,7 +139,7 @@ impl MemTable { for part_idx in 0..partition_count { let task = state.task_ctx(); - let exec = exec.clone(); + let exec = Arc::clone(&exec); join_set.spawn(async move { let stream = exec.execute(part_idx, task)?; common::collect(stream).await @@ -162,7 +162,7 @@ impl MemTable { } } - let exec = MemoryExec::try_new(&data, schema.clone(), None)?; + let exec = MemoryExec::try_new(&data, Arc::clone(&schema), None)?; if let Some(num_partitions) = output_partitions { let exec = RepartitionExec::try_new( @@ -183,9 +183,9 @@ impl MemTable { output_partitions.push(batches); } - return MemTable::try_new(schema.clone(), output_partitions); + return MemTable::try_new(Arc::clone(&schema), output_partitions); } - MemTable::try_new(schema.clone(), data) + MemTable::try_new(Arc::clone(&schema), data) } } @@ -196,7 +196,7 @@ impl TableProvider for MemTable { } fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } fn constraints(&self) -> Option<&Constraints> { @@ -297,7 +297,7 @@ impl TableProvider for MemTable { Ok(Arc::new(DataSinkExec::new( input, sink, - self.schema.clone(), + Arc::clone(&self.schema), None, ))) } diff --git a/datafusion/core/src/datasource/physical_plan/arrow_file.rs b/datafusion/core/src/datasource/physical_plan/arrow_file.rs index 39625a55ca15..df5ede5e8391 
100644 --- a/datafusion/core/src/datasource/physical_plan/arrow_file.rs +++ b/datafusion/core/src/datasource/physical_plan/arrow_file.rs @@ -63,7 +63,7 @@ impl ArrowExec { let (projected_schema, projected_statistics, projected_output_ordering) = base_config.project(); let cache = Self::compute_properties( - projected_schema.clone(), + Arc::clone(&projected_schema), &projected_output_ordering, &base_config, ); @@ -207,7 +207,7 @@ impl ExecutionPlan for ArrowExec { Some(Arc::new(Self { base_config: new_config, projected_statistics: self.projected_statistics.clone(), - projected_schema: self.projected_schema.clone(), + projected_schema: Arc::clone(&self.projected_schema), projected_output_ordering: self.projected_output_ordering.clone(), metrics: self.metrics.clone(), cache: self.cache.clone(), @@ -222,7 +222,7 @@ pub struct ArrowOpener { impl FileOpener for ArrowOpener { fn open(&self, file_meta: FileMeta) -> Result { - let object_store = self.object_store.clone(); + let object_store = Arc::clone(&self.object_store); let projection = self.projection.clone(); Ok(Box::pin(async move { let range = file_meta.range.clone(); diff --git a/datafusion/core/src/datasource/physical_plan/avro.rs b/datafusion/core/src/datasource/physical_plan/avro.rs index ce72c4087424..2e83be212f8b 100644 --- a/datafusion/core/src/datasource/physical_plan/avro.rs +++ b/datafusion/core/src/datasource/physical_plan/avro.rs @@ -51,7 +51,7 @@ impl AvroExec { let (projected_schema, projected_statistics, projected_output_ordering) = base_config.project(); let cache = Self::compute_properties( - projected_schema.clone(), + Arc::clone(&projected_schema), &projected_output_ordering, &base_config, ); @@ -175,7 +175,7 @@ impl ExecutionPlan for AvroExec { Some(Arc::new(Self { base_config: new_config, projected_statistics: self.projected_statistics.clone(), - projected_schema: self.projected_schema.clone(), + projected_schema: Arc::clone(&self.projected_schema), projected_output_ordering: self.projected_output_ordering.clone(), metrics: self.metrics.clone(), cache: self.cache.clone(), @@ -205,7 +205,7 @@ mod private { fn open(&self, reader: R) -> Result> { AvroReader::try_new( reader, - self.schema.clone(), + Arc::clone(&self.schema), self.batch_size, self.projection.clone(), ) @@ -218,7 +218,7 @@ mod private { impl FileOpener for AvroOpener { fn open(&self, file_meta: FileMeta) -> Result { - let config = self.config.clone(); + let config = Arc::clone(&self.config); Ok(Box::pin(async move { let r = config.object_store.get(file_meta.location()).await?; match r.payload { diff --git a/datafusion/core/src/datasource/physical_plan/csv.rs b/datafusion/core/src/datasource/physical_plan/csv.rs index 5beffc3b0581..1679acf30342 100644 --- a/datafusion/core/src/datasource/physical_plan/csv.rs +++ b/datafusion/core/src/datasource/physical_plan/csv.rs @@ -521,7 +521,7 @@ impl CsvConfig { } fn builder(&self) -> csv::ReaderBuilder { - let mut builder = csv::ReaderBuilder::new(self.file_schema.clone()) + let mut builder = csv::ReaderBuilder::new(Arc::clone(&self.file_schema)) .with_delimiter(self.delimiter) .with_batch_size(self.batch_size) .with_header(self.has_header) @@ -611,7 +611,7 @@ impl FileOpener for CsvOpener { ); } - let store = self.config.object_store.clone(); + let store = Arc::clone(&self.config.object_store); Ok(Box::pin(async move { // Current partition contains bytes [start_byte, end_byte) (might contain incomplete lines at boundaries) @@ -698,12 +698,12 @@ pub async fn plan_to_csv( let store = 
task_ctx.runtime_env().object_store(&object_store_url)?; let mut join_set = JoinSet::new(); for i in 0..plan.output_partitioning().partition_count() { - let storeref = store.clone(); - let plan: Arc = plan.clone(); + let storeref = Arc::clone(&store); + let plan: Arc = Arc::clone(&plan); let filename = format!("{}/part-{i}.csv", parsed.prefix()); let file = object_store::path::Path::parse(filename)?; - let mut stream = plan.execute(i, task_ctx.clone())?; + let mut stream = plan.execute(i, Arc::clone(&task_ctx))?; join_set.spawn(async move { let mut buf_writer = BufWriter::new(storeref, file.clone()); let mut buffer = Vec::with_capacity(1024); diff --git a/datafusion/core/src/datasource/physical_plan/file_stream.rs b/datafusion/core/src/datasource/physical_plan/file_stream.rs index 9d78a0f2e3f8..18cda4524ab2 100644 --- a/datafusion/core/src/datasource/physical_plan/file_stream.rs +++ b/datafusion/core/src/datasource/physical_plan/file_stream.rs @@ -24,6 +24,7 @@ use std::collections::VecDeque; use std::mem; use std::pin::Pin; +use std::sync::Arc; use std::task::{Context, Poll}; use crate::datasource::listing::PartitionedFile; @@ -252,7 +253,7 @@ impl FileStream { ) -> Result { let (projected_schema, ..) = config.project(); let pc_projector = PartitionColumnProjector::new( - projected_schema.clone(), + Arc::clone(&projected_schema), &config .table_partition_cols .iter() @@ -510,7 +511,7 @@ impl Stream for FileStream { impl RecordBatchStream for FileStream { fn schema(&self) -> SchemaRef { - self.projected_schema.clone() + Arc::clone(&self.projected_schema) } } diff --git a/datafusion/core/src/datasource/physical_plan/json.rs b/datafusion/core/src/datasource/physical_plan/json.rs index cf8f129a5036..6cb9d9df7047 100644 --- a/datafusion/core/src/datasource/physical_plan/json.rs +++ b/datafusion/core/src/datasource/physical_plan/json.rs @@ -262,8 +262,8 @@ impl FileOpener for JsonOpener { /// /// See [`CsvOpener`](super::CsvOpener) for an example. 
fn open(&self, file_meta: FileMeta) -> Result { - let store = self.object_store.clone(); - let schema = self.projected_schema.clone(); + let store = Arc::clone(&self.object_store); + let schema = Arc::clone(&self.projected_schema); let batch_size = self.batch_size; let file_compression_type = self.file_compression_type.to_owned(); @@ -355,12 +355,12 @@ pub async fn plan_to_json( let store = task_ctx.runtime_env().object_store(&object_store_url)?; let mut join_set = JoinSet::new(); for i in 0..plan.output_partitioning().partition_count() { - let storeref = store.clone(); - let plan: Arc = plan.clone(); + let storeref = Arc::clone(&store); + let plan: Arc = Arc::clone(&plan); let filename = format!("{}/part-{i}.json", parsed.prefix()); let file = object_store::path::Path::parse(filename)?; - let mut stream = plan.execute(i, task_ctx.clone())?; + let mut stream = plan.execute(i, Arc::clone(&task_ctx))?; join_set.spawn(async move { let mut buf_writer = BufWriter::new(storeref, file.clone()); diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs index 980f796a53b2..9dd0b9e206a9 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -425,7 +425,7 @@ impl ParquetExecBuilder { let pruning_predicate = predicate .clone() .and_then(|predicate_expr| { - match PruningPredicate::try_new(predicate_expr, file_schema.clone()) { + match PruningPredicate::try_new(predicate_expr, Arc::clone(file_schema)) { Ok(pruning_predicate) => Some(Arc::new(pruning_predicate)), Err(e) => { debug!("Could not create pruning predicate for: {e}"); @@ -439,7 +439,7 @@ impl ParquetExecBuilder { let page_pruning_predicate = predicate .as_ref() .map(|predicate_expr| { - PagePruningAccessPlanFilter::new(predicate_expr, file_schema.clone()) + PagePruningAccessPlanFilter::new(predicate_expr, Arc::clone(file_schema)) }) .map(Arc::new); @@ -807,7 +807,7 @@ impl ExecutionPlan for ParquetExec { predicate: self.predicate.clone(), pruning_predicate: self.pruning_predicate.clone(), page_pruning_predicate: self.page_pruning_predicate.clone(), - table_schema: self.base_config.file_schema.clone(), + table_schema: Arc::clone(&self.base_config.file_schema), metadata_size_hint: self.metadata_size_hint, metrics: self.metrics.clone(), parquet_file_reader_factory, diff --git a/datafusion/core/src/datasource/physical_plan/parquet/opener.rs b/datafusion/core/src/datasource/physical_plan/parquet/opener.rs index 3492bcc85f02..883f296f3b95 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/opener.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/opener.rs @@ -105,11 +105,11 @@ impl FileOpener for ParquetOpener { SchemaRef::from(self.table_schema.project(&self.projection)?); let schema_adapter = self .schema_adapter_factory - .create(projected_schema, self.table_schema.clone()); + .create(projected_schema, Arc::clone(&self.table_schema)); let predicate = self.predicate.clone(); let pruning_predicate = self.pruning_predicate.clone(); let page_pruning_predicate = self.page_pruning_predicate.clone(); - let table_schema = self.table_schema.clone(); + let table_schema = Arc::clone(&self.table_schema); let reorder_predicates = self.reorder_filters; let pushdown_filters = self.pushdown_filters; let enable_page_index = should_enable_page_index( diff --git a/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs 
b/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs index ced07de974f6..07f50bca1d1d 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs @@ -118,14 +118,16 @@ impl PagePruningAccessPlanFilter { let predicates = split_conjunction(expr) .into_iter() .filter_map(|predicate| { - let pp = - match PruningPredicate::try_new(predicate.clone(), schema.clone()) { - Ok(pp) => pp, - Err(e) => { - debug!("Ignoring error creating page pruning predicate: {e}"); - return None; - } - }; + let pp = match PruningPredicate::try_new( + Arc::clone(predicate), + Arc::clone(&schema), + ) { + Ok(pp) => pp, + Err(e) => { + debug!("Ignoring error creating page pruning predicate: {e}"); + return None; + } + }; if pp.always_true() { debug!("Ignoring always true page pruning predicate: {predicate}"); diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs index e876f840d1eb..a97e7c7d2552 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs @@ -541,7 +541,7 @@ pub fn build_row_filter( let mut candidates: Vec = predicates .into_iter() .map(|expr| { - FilterCandidateBuilder::new(expr.clone(), file_schema, table_schema) + FilterCandidateBuilder::new(Arc::clone(expr), file_schema, table_schema) .build(metadata) }) .collect::, _>>()? diff --git a/datafusion/core/src/datasource/physical_plan/parquet/writer.rs b/datafusion/core/src/datasource/physical_plan/parquet/writer.rs index 0c0c54691068..00926dc2330b 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/writer.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/writer.rs @@ -40,14 +40,14 @@ pub async fn plan_to_parquet( let store = task_ctx.runtime_env().object_store(&object_store_url)?; let mut join_set = JoinSet::new(); for i in 0..plan.output_partitioning().partition_count() { - let plan: Arc = plan.clone(); + let plan: Arc = Arc::clone(&plan); let filename = format!("{}/part-{i}.parquet", parsed.prefix()); let file = Path::parse(filename)?; let propclone = writer_properties.clone(); - let storeref = store.clone(); + let storeref = Arc::clone(&store); let buf_writer = BufWriter::new(storeref, file.clone()); - let mut stream = plan.execute(i, task_ctx.clone())?; + let mut stream = plan.execute(i, Arc::clone(&task_ctx))?; join_set.spawn(async move { let mut writer = AsyncArrowWriter::try_new(buf_writer, plan.schema(), propclone)?; diff --git a/datafusion/core/src/datasource/schema_adapter.rs b/datafusion/core/src/datasource/schema_adapter.rs index 71f947e2c039..b27cf9c5f833 100644 --- a/datafusion/core/src/datasource/schema_adapter.rs +++ b/datafusion/core/src/datasource/schema_adapter.rs @@ -288,9 +288,9 @@ impl SchemaAdapter for DefaultSchemaAdapter { Ok(( Arc::new(SchemaMapping { - projected_table_schema: self.projected_table_schema.clone(), + projected_table_schema: Arc::clone(&self.projected_table_schema), field_mappings, - table_schema: self.table_schema.clone(), + table_schema: Arc::clone(&self.table_schema), }), projection, )) @@ -372,7 +372,7 @@ impl SchemaMapper for SchemaMapping { // Necessary to handle empty batches let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows())); - let schema = self.projected_table_schema.clone(); + let schema = Arc::clone(&self.projected_table_schema); let record_batch = 
RecordBatch::try_new_with_options(schema, cols, &options)?; Ok(record_batch) } diff --git a/datafusion/core/src/datasource/stream.rs b/datafusion/core/src/datasource/stream.rs index 34023fbbb620..d8fad5b6cd37 100644 --- a/datafusion/core/src/datasource/stream.rs +++ b/datafusion/core/src/datasource/stream.rs @@ -187,7 +187,7 @@ impl StreamProvider for FileStreamProvider { fn reader(&self) -> Result> { let file = File::open(&self.location)?; - let schema = self.schema.clone(); + let schema = Arc::clone(&self.schema); match &self.encoding { StreamEncoding::Csv => { let reader = arrow::csv::ReaderBuilder::new(schema) @@ -311,7 +311,7 @@ impl TableProvider for StreamTable { } fn schema(&self) -> SchemaRef { - self.0.source.schema().clone() + Arc::clone(self.0.source.schema()) } fn constraints(&self) -> Option<&Constraints> { @@ -338,8 +338,8 @@ impl TableProvider for StreamTable { }; Ok(Arc::new(StreamingTableExec::try_new( - self.0.source.schema().clone(), - vec![Arc::new(StreamRead(self.0.clone())) as _], + Arc::clone(self.0.source.schema()), + vec![Arc::new(StreamRead(Arc::clone(&self.0))) as _], projection, projected_schema, true, @@ -365,8 +365,8 @@ impl TableProvider for StreamTable { Ok(Arc::new(DataSinkExec::new( input, - Arc::new(StreamWrite(self.0.clone())), - self.0.source.schema().clone(), + Arc::new(StreamWrite(Arc::clone(&self.0))), + Arc::clone(self.0.source.schema()), ordering, ))) } @@ -381,8 +381,8 @@ impl PartitionStream for StreamRead { } fn execute(&self, _ctx: Arc) -> SendableRecordBatchStream { - let config = self.0.clone(); - let schema = self.0.source.schema().clone(); + let config = Arc::clone(&self.0); + let schema = Arc::clone(self.0.source.schema()); let mut builder = RecordBatchReceiverStreamBuilder::new(schema, 2); let tx = builder.tx(); builder.spawn_blocking(move || { @@ -422,7 +422,7 @@ impl DataSink for StreamWrite { mut data: SendableRecordBatchStream, _context: &Arc, ) -> Result { - let config = self.0.clone(); + let config = Arc::clone(&self.0); let (sender, mut receiver) = tokio::sync::mpsc::channel::(2); // Note: FIFO Files support poll so this could use AsyncFd let write_task = SpawnedTask::spawn_blocking(move || { diff --git a/datafusion/core/src/datasource/streaming.rs b/datafusion/core/src/datasource/streaming.rs index 0a14cfefcdf2..1da3c3da9c89 100644 --- a/datafusion/core/src/datasource/streaming.rs +++ b/datafusion/core/src/datasource/streaming.rs @@ -76,7 +76,7 @@ impl TableProvider for StreamingTable { } fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } fn table_type(&self) -> TableType { @@ -91,7 +91,7 @@ impl TableProvider for StreamingTable { limit: Option, ) -> Result> { Ok(Arc::new(StreamingTableExec::try_new( - self.schema.clone(), + Arc::clone(&self.schema), self.partitions.clone(), projection, None, diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index e04fe6bddec9..a2093c39fc7b 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -508,7 +508,7 @@ impl SessionContext { /// Return the [RuntimeEnv] used to run queries with this `SessionContext` pub fn runtime_env(&self) -> Arc { - self.state.read().runtime_env().clone() + Arc::clone(self.state.read().runtime_env()) } /// Returns an id that uniquely identifies this `SessionContext`. 
@@ -1545,7 +1545,7 @@ impl SessionContext { /// Get reference to [`SessionState`] pub fn state_ref(&self) -> Arc> { - self.state.clone() + Arc::clone(&self.state) } /// Get weak reference to [`SessionState`] diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index d0bbc95a1b08..6172783ab832 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -300,9 +300,9 @@ impl SessionState { let resolved_ref = self.resolve_table_ref(table_ref); if self.config.information_schema() && *resolved_ref.schema == *INFORMATION_SCHEMA { - return Ok(Arc::new(InformationSchemaProvider::new( - self.catalog_list.clone(), - ))); + return Ok(Arc::new(InformationSchemaProvider::new(Arc::clone( + &self.catalog_list, + )))); } self.catalog_list @@ -649,9 +649,9 @@ impl SessionState { return Ok(LogicalPlan::Explain(Explain { verbose: e.verbose, - plan: e.plan.clone(), + plan: Arc::clone(&e.plan), stringified_plans, - schema: e.schema.clone(), + schema: Arc::clone(&e.schema), logical_optimization_succeeded: false, })); } @@ -678,7 +678,7 @@ impl SessionState { let plan_type = PlanType::OptimizedLogicalPlan { optimizer_name }; stringified_plans .push(StringifiedPlan::new(plan_type, err.to_string())); - (e.plan.clone(), false) + (Arc::clone(&e.plan), false) } Err(e) => return Err(e), }; @@ -687,7 +687,7 @@ impl SessionState { verbose: e.verbose, plan, stringified_plans, - schema: e.schema.clone(), + schema: Arc::clone(&e.schema), logical_optimization_succeeded, })) } else { @@ -909,7 +909,7 @@ impl SessionState { name: &str, ) -> datafusion_common::Result>> { let udtf = self.table_functions.remove(name); - Ok(udtf.map(|x| x.function().clone())) + Ok(udtf.map(|x| Arc::clone(x.function()))) } /// Store the logical plan and the parameter types of a prepared statement. 
@@ -1711,7 +1711,7 @@ impl<'a> ContextProvider for SessionContextProvider<'a> { .ok_or(plan_datafusion_err!( "There is no registered file format with ext {ext}" )) - .map(|file_type| format_as_file_type(file_type.clone())) + .map(|file_type| format_as_file_type(Arc::clone(file_type))) } } @@ -1749,7 +1749,8 @@ impl FunctionRegistry for SessionState { udf: Arc, ) -> datafusion_common::Result>> { udf.aliases().iter().for_each(|alias| { - self.scalar_functions.insert(alias.clone(), udf.clone()); + self.scalar_functions + .insert(alias.clone(), Arc::clone(&udf)); }); Ok(self.scalar_functions.insert(udf.name().into(), udf)) } @@ -1759,7 +1760,8 @@ impl FunctionRegistry for SessionState { udaf: Arc, ) -> datafusion_common::Result>> { udaf.aliases().iter().for_each(|alias| { - self.aggregate_functions.insert(alias.clone(), udaf.clone()); + self.aggregate_functions + .insert(alias.clone(), Arc::clone(&udaf)); }); Ok(self.aggregate_functions.insert(udaf.name().into(), udaf)) } @@ -1769,7 +1771,8 @@ impl FunctionRegistry for SessionState { udwf: Arc, ) -> datafusion_common::Result>> { udwf.aliases().iter().for_each(|alias| { - self.window_functions.insert(alias.clone(), udwf.clone()); + self.window_functions + .insert(alias.clone(), Arc::clone(&udwf)); }); Ok(self.window_functions.insert(udwf.name().into(), udwf)) } @@ -1863,7 +1866,7 @@ impl From<&SessionState> for TaskContext { state.scalar_functions.clone(), state.aggregate_functions.clone(), state.window_functions.clone(), - state.runtime_env.clone(), + Arc::clone(&state.runtime_env), ) } } diff --git a/datafusion/core/src/execution/session_state_defaults.rs b/datafusion/core/src/execution/session_state_defaults.rs index b5370efa0a97..7ba332c520c1 100644 --- a/datafusion/core/src/execution/session_state_defaults.rs +++ b/datafusion/core/src/execution/session_state_defaults.rs @@ -193,8 +193,13 @@ impl SessionStateDefaults { Some(factory) => factory, _ => return, }; - let schema = - ListingSchemaProvider::new(authority, path, factory.clone(), store, format); + let schema = ListingSchemaProvider::new( + authority, + path, + Arc::clone(factory), + store, + format, + ); let _ = default_catalog .register_schema("default", Arc::new(schema)) .expect("Failed to register default schema"); diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index 9d1574f5156e..b2df32a62e44 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +// Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 +#![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] #![warn(missing_docs, clippy::needless_borrow)] //! 
[DataFusion] is an extensible query engine written in Rust that diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index 4bbb995c365e..82fde60de090 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -274,7 +274,7 @@ impl PhysicalOptimizerRule for EnforceDistribution { fn adjust_input_keys_ordering( mut requirements: PlanWithKeyRequirements, ) -> Result> { - let plan = requirements.plan.clone(); + let plan = Arc::clone(&requirements.plan); if let Some(HashJoinExec { left, @@ -295,8 +295,8 @@ fn adjust_input_keys_ordering( Vec, )| { HashJoinExec::try_new( - left.clone(), - right.clone(), + Arc::clone(left), + Arc::clone(right), new_conditions.0, filter.clone(), join_type, @@ -362,8 +362,8 @@ fn adjust_input_keys_ordering( Vec, )| { SortMergeJoinExec::try_new( - left.clone(), - right.clone(), + Arc::clone(left), + Arc::clone(right), new_conditions.0, filter.clone(), *join_type, @@ -495,8 +495,8 @@ fn reorder_aggregate_keys( PhysicalGroupBy::new_single(new_group_exprs), agg_exec.aggr_expr().to_vec(), agg_exec.filter_expr().to_vec(), - agg_exec.input().clone(), - agg_exec.input_schema.clone(), + Arc::clone(agg_exec.input()), + Arc::clone(&agg_exec.input_schema), )?); // Build new group expressions that correspond to the output // of the "reordered" aggregator: @@ -514,11 +514,11 @@ fn reorder_aggregate_keys( new_group_by, agg_exec.aggr_expr().to_vec(), agg_exec.filter_expr().to_vec(), - partial_agg.clone(), + Arc::clone(&partial_agg) as _, agg_exec.input_schema(), )?); - agg_node.plan = new_final_agg.clone(); + agg_node.plan = Arc::clone(&new_final_agg) as _; agg_node.data.clear(); agg_node.children = vec![PlanWithKeyRequirements::new( partial_agg as _, @@ -624,8 +624,8 @@ pub(crate) fn reorder_join_keys_to_inputs( } = join_keys; let new_join_on = new_join_conditions(&left_keys, &right_keys); return Ok(Arc::new(HashJoinExec::try_new( - left.clone(), - right.clone(), + Arc::clone(left), + Arc::clone(right), new_join_on, filter.clone(), join_type, @@ -664,8 +664,8 @@ pub(crate) fn reorder_join_keys_to_inputs( .map(|idx| sort_options[positions[idx]]) .collect(); return SortMergeJoinExec::try_new( - left.clone(), - right.clone(), + Arc::clone(left), + Arc::clone(right), new_join_on, filter.clone(), *join_type, @@ -726,19 +726,19 @@ fn try_reorder( } else if !equivalence_properties.eq_group().is_empty() { normalized_expected = expected .iter() - .map(|e| eq_groups.normalize_expr(e.clone())) + .map(|e| eq_groups.normalize_expr(Arc::clone(e))) .collect(); normalized_left_keys = join_keys .left_keys .iter() - .map(|e| eq_groups.normalize_expr(e.clone())) + .map(|e| eq_groups.normalize_expr(Arc::clone(e))) .collect(); normalized_right_keys = join_keys .right_keys .iter() - .map(|e| eq_groups.normalize_expr(e.clone())) + .map(|e| eq_groups.normalize_expr(Arc::clone(e))) .collect(); if physical_exprs_equal(&normalized_expected, &normalized_left_keys) @@ -761,8 +761,8 @@ fn try_reorder( let mut new_left_keys = vec![]; let mut new_right_keys = vec![]; for pos in positions.iter() { - new_left_keys.push(join_keys.left_keys[*pos].clone()); - new_right_keys.push(join_keys.right_keys[*pos].clone()); + new_left_keys.push(Arc::clone(&join_keys.left_keys[*pos])); + new_right_keys.push(Arc::clone(&join_keys.right_keys[*pos])); } let pairs = JoinKeyPairs { left_keys: new_left_keys, @@ -800,7 +800,7 @@ fn expected_expr_positions( fn 
extract_join_keys(on: &[(PhysicalExprRef, PhysicalExprRef)]) -> JoinKeyPairs { let (left_keys, right_keys) = on .iter() - .map(|(l, r)| (l.clone() as _, r.clone() as _)) + .map(|(l, r)| (Arc::clone(l) as _, Arc::clone(r) as _)) .unzip(); JoinKeyPairs { left_keys, @@ -815,7 +815,7 @@ fn new_join_conditions( new_left_keys .iter() .zip(new_right_keys.iter()) - .map(|(l_key, r_key)| (l_key.clone(), r_key.clone())) + .map(|(l_key, r_key)| (Arc::clone(l_key), Arc::clone(r_key))) .collect() } @@ -844,8 +844,9 @@ fn add_roundrobin_on_top( // - Usage of order preserving variants is not desirable // (determined by flag `config.optimizer.prefer_existing_sort`) let partitioning = Partitioning::RoundRobinBatch(n_target); - let repartition = RepartitionExec::try_new(input.plan.clone(), partitioning)? - .with_preserve_order(); + let repartition = + RepartitionExec::try_new(Arc::clone(&input.plan), partitioning)? + .with_preserve_order(); let new_plan = Arc::new(repartition) as _; @@ -902,8 +903,9 @@ fn add_hash_on_top( // - Usage of order preserving variants is not desirable (per the flag // `config.optimizer.prefer_existing_sort`). let partitioning = dist.create_partitioning(n_target); - let repartition = RepartitionExec::try_new(input.plan.clone(), partitioning)? - .with_preserve_order(); + let repartition = + RepartitionExec::try_new(Arc::clone(&input.plan), partitioning)? + .with_preserve_order(); let plan = Arc::new(repartition) as _; return Ok(DistributionContext::new(plan, true, vec![input])); @@ -941,10 +943,10 @@ fn add_spm_on_top(input: DistributionContext) -> DistributionContext { .output_ordering() .unwrap_or(&LexOrdering::default()) .clone(), - input.plan.clone(), + Arc::clone(&input.plan), )) as _ } else { - Arc::new(CoalescePartitionsExec::new(input.plan.clone())) as _ + Arc::new(CoalescePartitionsExec::new(Arc::clone(&input.plan))) as _ }; DistributionContext::new(new_plan, true, vec![input]) @@ -1020,7 +1022,7 @@ fn replace_order_preserving_variants( .collect::>>()?; if is_sort_preserving_merge(&context.plan) { - let child_plan = context.children[0].plan.clone(); + let child_plan = Arc::clone(&context.children[0].plan); context.plan = Arc::new(CoalescePartitionsExec::new(child_plan)); return Ok(context); } else if let Some(repartition) = @@ -1028,7 +1030,7 @@ fn replace_order_preserving_variants( { if repartition.preserve_order() { context.plan = Arc::new(RepartitionExec::try_new( - context.children[0].plan.clone(), + Arc::clone(&context.children[0].plan), repartition.partitioning().clone(), )?); return Ok(context); @@ -1306,7 +1308,10 @@ fn ensure_distribution( ) .collect::>>()?; - let children_plans = children.iter().map(|c| c.plan.clone()).collect::>(); + let children_plans = children + .iter() + .map(|c| Arc::clone(&c.plan)) + .collect::>(); plan = if plan.as_any().is::() && !config.optimizer.prefer_existing_union diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs index 636d52ccc9cd..cfc08562f7d7 100644 --- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs @@ -213,7 +213,7 @@ fn replace_with_partial_sort( ) -> Result> { let plan_any = plan.as_any(); if let Some(sort_plan) = plan_any.downcast_ref::() { - let child = sort_plan.children()[0].clone(); + let child = Arc::clone(sort_plan.children()[0]); if !child.execution_mode().is_unbounded() { return Ok(plan); } @@ -233,7 +233,7 @@ fn replace_with_partial_sort( return Ok(Arc::new( 
PartialSortExec::new( LexOrdering::new(sort_plan.expr().to_vec()), - sort_plan.input().clone(), + Arc::clone(sort_plan.input()), common_prefix_length, ) .with_preserve_partitioning(sort_plan.preserve_partitioning()) @@ -290,7 +290,8 @@ fn parallelize_sorts( requirements = add_sort_above_with_check(requirements, sort_reqs, fetch); - let spm = SortPreservingMergeExec::new(sort_exprs, requirements.plan.clone()); + let spm = + SortPreservingMergeExec::new(sort_exprs, Arc::clone(&requirements.plan)); Ok(Transformed::yes( PlanWithCorrespondingCoalescePartitions::new( Arc::new(spm.with_fetch(fetch)), @@ -307,7 +308,7 @@ fn parallelize_sorts( Ok(Transformed::yes( PlanWithCorrespondingCoalescePartitions::new( - Arc::new(CoalescePartitionsExec::new(requirements.plan.clone())), + Arc::new(CoalescePartitionsExec::new(Arc::clone(&requirements.plan))), false, vec![requirements], ), @@ -402,7 +403,7 @@ fn analyze_immediate_sort_removal( { // Replace the sort with a sort-preserving merge: let expr = LexOrdering::new(sort_exec.expr().to_vec()); - Arc::new(SortPreservingMergeExec::new(expr, sort_input.clone())) as _ + Arc::new(SortPreservingMergeExec::new(expr, Arc::clone(sort_input))) as _ } else { // Remove the sort: node.children = node.children.swap_remove(0).children; @@ -414,12 +415,16 @@ fn analyze_immediate_sort_removal( .partition_count() == 1 { - Arc::new(GlobalLimitExec::new(sort_input.clone(), 0, Some(fetch))) + Arc::new(GlobalLimitExec::new( + Arc::clone(sort_input), + 0, + Some(fetch), + )) } else { - Arc::new(LocalLimitExec::new(sort_input.clone(), fetch)) + Arc::new(LocalLimitExec::new(Arc::clone(sort_input), fetch)) } } else { - sort_input.clone() + Arc::clone(sort_input) } }; for child in node.children.iter_mut() { @@ -479,7 +484,7 @@ fn adjust_window_sort_removal( // Satisfy the ordering requirement so that the window can run: let mut child_node = window_tree.children.swap_remove(0); child_node = add_sort_above(child_node, reqs, None); - let child_plan = child_node.plan.clone(); + let child_plan = Arc::clone(&child_node.plan); window_tree.children.push(child_node); if window_expr.iter().all(|e| e.uses_bounded_memory()) { @@ -604,12 +609,12 @@ fn remove_corresponding_sort_from_sub_plan( // Replace with variants that do not preserve order. if is_sort_preserving_merge(&node.plan) { node.children = node.children.swap_remove(0).children; - node.plan = node.plan.children().swap_remove(0).clone(); + node.plan = Arc::clone(node.plan.children().swap_remove(0)); } else if let Some(repartition) = node.plan.as_any().downcast_ref::() { node.plan = Arc::new(RepartitionExec::try_new( - node.children[0].plan.clone(), + Arc::clone(&node.children[0].plan), repartition.properties().output_partitioning().clone(), )?) as _; } @@ -620,7 +625,7 @@ fn remove_corresponding_sort_from_sub_plan( { // If there is existing ordering, to preserve ordering use // `SortPreservingMergeExec` instead of a `CoalescePartitionsExec`. 
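// [editor's aside, not part of the patch] The repeated `- x.clone()` / `+ Arc::clone(&x)`
// substitutions in these hunks are behaviourally identical for reference-counted plans and
// expressions; the explicit form just makes the cheap pointer copy obvious (the style the
// `clippy::clone_on_ref_ptr` lint asks for). A minimal sketch of the equivalence:
//
//     use std::sync::Arc;
//     let a: Arc<str> = Arc::from("plan");
//     let b = a.clone();       // bumps the refcount, but reads like a deep copy
//     let c = Arc::clone(&a);  // same cost, intent is explicit
//     assert!(Arc::ptr_eq(&b, &c));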
- let plan = node.plan.clone(); + let plan = Arc::clone(&node.plan); let plan = if let Some(ordering) = plan.output_ordering() { Arc::new(SortPreservingMergeExec::new( LexOrdering::new(ordering.to_vec()), diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs index 0312e362afb1..9b2402c6bb87 100644 --- a/datafusion/core/src/physical_optimizer/join_selection.rs +++ b/datafusion/core/src/physical_optimizer/join_selection.rs @@ -187,7 +187,7 @@ pub fn swap_hash_join( hash_join .on() .iter() - .map(|(l, r)| (r.clone(), l.clone())) + .map(|(l, r)| (Arc::clone(r), Arc::clone(l))) .collect(), swap_join_filter(hash_join.filter()), &swap_join_type(*hash_join.join_type()), @@ -289,7 +289,7 @@ fn swap_filter(filter: &JoinFilter) -> JoinFilter { .collect(); JoinFilter::new( - filter.expression().clone(), + Arc::clone(filter.expression()), column_indices, filter.schema().clone(), ) @@ -605,8 +605,8 @@ fn hash_join_convert_symmetric_subrule( let right_order = determine_order(JoinSide::Right); return SymmetricHashJoinExec::try_new( - hash_join.left().clone(), - hash_join.right().clone(), + Arc::clone(hash_join.left()), + Arc::clone(hash_join.right()), hash_join.on().to_vec(), hash_join.filter().cloned(), hash_join.join_type(), diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index 5aecf036ce18..2c2ff6d48aec 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -101,7 +101,7 @@ pub fn remove_unnecessary_projections( // If the projection does not cause any change on the input, we can // safely remove it: if is_projection_removable(projection) { - return Ok(Transformed::yes(projection.input().clone())); + return Ok(Transformed::yes(Arc::clone(projection.input()))); } // If it does, check if we can push it under its child(ren): let input = projection.input().as_any(); @@ -261,7 +261,7 @@ fn try_swapping_with_streaming_table( } StreamingTableExec::try_new( - streaming_table.partition_schema().clone(), + Arc::clone(streaming_table.partition_schema()), streaming_table.partitions().clone(), Some(new_projections.as_ref()), lex_orderings, @@ -297,7 +297,7 @@ fn try_unifying_projections( // beneficial as caching mechanism for non-trivial computations. // See discussion in: https://github.com/apache/datafusion/issues/8296 if column_ref_map.iter().any(|(column, count)| { - *count > 1 && !is_expr_trivial(&child.expr()[column.index()].0.clone()) + *count > 1 && !is_expr_trivial(&Arc::clone(&child.expr()[column.index()].0)) }) { return Ok(None); } @@ -312,7 +312,7 @@ fn try_unifying_projections( projected_exprs.push((expr, alias.clone())); } - ProjectionExec::try_new(projected_exprs, child.input().clone()) + ProjectionExec::try_new(projected_exprs, Arc::clone(child.input())) .map(|e| Some(Arc::new(e) as _)) } @@ -603,7 +603,7 @@ fn try_embed_projection( // Old projection may contain some alias or expression such as `a + 1` and `CAST('true' AS BOOLEAN)`, but our projection_exprs in hash join just contain column, so we need to create the new projection to keep the original projection. 
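// [editor's aside, not part of the patch] The `Arc::clone(&new_execution_plan) as _` form a
// few lines below clones the concrete Arc and lets the compiler coerce it into the trait
// object the constructor expects (an `Arc<dyn ExecutionPlan>` here). A self-contained sketch
// of that unsizing cast, using made-up stand-in types:
//
//     use std::sync::Arc;
//     trait Plan {}
//     struct Join;
//     impl Plan for Join {}
//     let concrete = Arc::new(Join);
//     let erased: Arc<dyn Plan> = Arc::clone(&concrete) as _;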
let new_projection = Arc::new(ProjectionExec::try_new( new_projection_exprs, - new_execution_plan.clone(), + Arc::clone(&new_execution_plan) as _, )?); if is_projection_removable(&new_projection) { Ok(Some(new_execution_plan)) @@ -1005,8 +1005,7 @@ fn update_expr( let mut state = RewriteState::Unchanged; - let new_expr = expr - .clone() + let new_expr = Arc::clone(expr) .transform_up(|expr: Arc| { if state == RewriteState::RewrittenInvalid { return Ok(Transformed::no(expr)); @@ -1018,7 +1017,9 @@ fn update_expr( if sync_with_child { state = RewriteState::RewrittenValid; // Update the index of `column`: - Ok(Transformed::yes(projected_exprs[column.index()].0.clone())) + Ok(Transformed::yes(Arc::clone( + &projected_exprs[column.index()].0, + ))) } else { // default to invalid, in case we can't find the relevant column state = RewriteState::RewrittenInvalid; @@ -1055,7 +1056,7 @@ fn make_with_child( projection: &ProjectionExec, child: &Arc, ) -> Result> { - ProjectionExec::try_new(projection.expr().to_vec(), child.clone()) + ProjectionExec::try_new(projection.expr().to_vec(), Arc::clone(child)) .map(|e| Arc::new(e) as _) } @@ -1155,8 +1156,7 @@ fn new_columns_for_join_on( .iter() .filter_map(|on| { // Rewrite all columns in `on` - (*on) - .clone() + Arc::clone(*on) .transform(|expr| { if let Some(column) = expr.as_any().downcast_ref::() { // Find the column in the projection expressions @@ -1219,7 +1219,7 @@ fn update_join_filter( == join_filter.column_indices().len()) .then(|| { JoinFilter::new( - join_filter.expression().clone(), + Arc::clone(join_filter.expression()), join_filter .column_indices() .iter() @@ -1302,7 +1302,7 @@ fn new_join_children( ) }) .collect_vec(), - left_child.clone(), + Arc::clone(left_child), )?; let left_size = left_child.schema().fields().len() as i32; let new_right = ProjectionExec::try_new( @@ -1320,7 +1320,7 @@ fn new_join_children( ) }) .collect_vec(), - right_child.clone(), + Arc::clone(right_child), )?; Ok((new_left, new_right)) diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index eb03b337779c..89b86471561e 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -504,7 +504,7 @@ impl Default for ConstantUnhandledPredicateHook { impl UnhandledPredicateHook for ConstantUnhandledPredicateHook { fn handle(&self, _expr: &Arc) -> Arc { - self.default.clone() + Arc::clone(&self.default) } } @@ -857,7 +857,7 @@ impl RequiredColumns { Field::new(stat_column.name(), field.data_type().clone(), nullable); self.columns.push((column.clone(), stat_type, stat_field)); } - rewrite_column_expr(column_expr.clone(), column, &stat_column) + rewrite_column_expr(Arc::clone(column_expr), column, &stat_column) } /// rewrite col --> col_min @@ -1130,7 +1130,7 @@ fn rewrite_expr_to_prunable( .is_some() { // `col op lit()` - Ok((column_expr.clone(), op, scalar_expr.clone())) + Ok((Arc::clone(column_expr), op, Arc::clone(scalar_expr))) } else if let Some(cast) = column_expr_any.downcast_ref::() { // `cast(col) op lit()` let arrow_schema: SchemaRef = schema.clone().into(); @@ -1175,8 +1175,8 @@ fn rewrite_expr_to_prunable( .downcast_ref::() .is_some() { - let left = not.arg().clone(); - let right = Arc::new(phys_expr::NotExpr::new(scalar_expr.clone())); + let left = Arc::clone(not.arg()); + let right = Arc::new(phys_expr::NotExpr::new(Arc::clone(scalar_expr))); Ok((left, reverse_operator(op)?, right)) } else { plan_err!("Not with complex expression 
{column_expr:?} is not supported") @@ -1462,9 +1462,9 @@ fn build_predicate_expression( .iter() .map(|e| { Arc::new(phys_expr::BinaryExpr::new( - in_list.expr().clone(), + Arc::clone(in_list.expr()), eq_op, - e.clone(), + Arc::clone(e), )) as _ }) .reduce(|a, b| Arc::new(phys_expr::BinaryExpr::new(a, re_op, b)) as _) @@ -1483,9 +1483,9 @@ fn build_predicate_expression( let (left, op, right) = { if let Some(bin_expr) = expr_any.downcast_ref::() { ( - bin_expr.left().clone(), + Arc::clone(bin_expr.left()), *bin_expr.op(), - bin_expr.right().clone(), + Arc::clone(bin_expr.right()), ) } else { return unhandled_hook.handle(expr); @@ -1538,11 +1538,11 @@ fn build_statistics_expr( Arc::new(phys_expr::BinaryExpr::new( min_column_expr, Operator::NotEq, - expr_builder.scalar_expr().clone(), + Arc::clone(expr_builder.scalar_expr()), )), Operator::Or, Arc::new(phys_expr::BinaryExpr::new( - expr_builder.scalar_expr().clone(), + Arc::clone(expr_builder.scalar_expr()), Operator::NotEq, max_column_expr, )), @@ -1557,11 +1557,11 @@ fn build_statistics_expr( Arc::new(phys_expr::BinaryExpr::new( min_column_expr, Operator::LtEq, - expr_builder.scalar_expr().clone(), + Arc::clone(expr_builder.scalar_expr()), )), Operator::And, Arc::new(phys_expr::BinaryExpr::new( - expr_builder.scalar_expr().clone(), + Arc::clone(expr_builder.scalar_expr()), Operator::LtEq, max_column_expr, )), @@ -1572,7 +1572,7 @@ fn build_statistics_expr( Arc::new(phys_expr::BinaryExpr::new( expr_builder.max_column_expr()?, Operator::Gt, - expr_builder.scalar_expr().clone(), + Arc::clone(expr_builder.scalar_expr()), )) } Operator::GtEq => { @@ -1580,7 +1580,7 @@ fn build_statistics_expr( Arc::new(phys_expr::BinaryExpr::new( expr_builder.max_column_expr()?, Operator::GtEq, - expr_builder.scalar_expr().clone(), + Arc::clone(expr_builder.scalar_expr()), )) } Operator::Lt => { @@ -1588,7 +1588,7 @@ fn build_statistics_expr( Arc::new(phys_expr::BinaryExpr::new( expr_builder.min_column_expr()?, Operator::Lt, - expr_builder.scalar_expr().clone(), + Arc::clone(expr_builder.scalar_expr()), )) } Operator::LtEq => { @@ -1596,7 +1596,7 @@ fn build_statistics_expr( Arc::new(phys_expr::BinaryExpr::new( expr_builder.min_column_expr()?, Operator::LtEq, - expr_builder.scalar_expr().clone(), + Arc::clone(expr_builder.scalar_expr()), )) } // other expressions are not supported diff --git a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs index c80aea411f57..7fc3adf784e2 100644 --- a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs @@ -120,7 +120,7 @@ fn plan_with_order_preserving_variants( { // When a `RepartitionExec` doesn't preserve ordering, replace it with // a sort-preserving variant if appropriate: - let child = sort_input.children[0].plan.clone(); + let child = Arc::clone(&sort_input.children[0].plan); let partitioning = sort_input.plan.output_partitioning().clone(); sort_input.plan = Arc::new( RepartitionExec::try_new(child, partitioning)?.with_preserve_order(), @@ -134,7 +134,7 @@ fn plan_with_order_preserving_variants( // replace it with a `SortPreservingMergeExec` if appropriate: let spm = SortPreservingMergeExec::new( LexOrdering::new(ordering.inner.clone()), - child.clone(), + Arc::clone(child), ); sort_input.plan = Arc::new(spm) as _; sort_input.children[0].data = true; @@ -179,12 +179,12 @@ fn 
plan_with_order_breaking_variants( if is_repartition(plan) && plan.maintains_input_order()[0] { // When a `RepartitionExec` preserves ordering, replace it with a // non-sort-preserving variant: - let child = sort_input.children[0].plan.clone(); + let child = Arc::clone(&sort_input.children[0].plan); let partitioning = plan.output_partitioning().clone(); sort_input.plan = Arc::new(RepartitionExec::try_new(child, partitioning)?) as _; } else if is_sort_preserving_merge(plan) { // Replace `SortPreservingMergeExec` with a `CoalescePartitionsExec`: - let child = sort_input.children[0].plan.clone(); + let child = Arc::clone(&sort_input.children[0].plan); let coalesce = CoalescePartitionsExec::new(child); sort_input.plan = Arc::new(coalesce) as _; } else { diff --git a/datafusion/core/src/physical_optimizer/sanity_checker.rs b/datafusion/core/src/physical_optimizer/sanity_checker.rs index 01d3cd1aab29..b2f2c933c1d1 100644 --- a/datafusion/core/src/physical_optimizer/sanity_checker.rs +++ b/datafusion/core/src/physical_optimizer/sanity_checker.rs @@ -119,7 +119,7 @@ pub fn check_plan_sanity( plan: Arc, optimizer_options: &OptimizerOptions, ) -> Result>> { - check_finiteness_requirements(plan.clone(), optimizer_options)?; + check_finiteness_requirements(Arc::clone(&plan), optimizer_options)?; for ((idx, child), sort_req, dist_req) in izip!( plan.children().iter().enumerate(), diff --git a/datafusion/core/src/physical_optimizer/sort_pushdown.rs b/datafusion/core/src/physical_optimizer/sort_pushdown.rs index 42d682169da8..d48c7118cb8e 100644 --- a/datafusion/core/src/physical_optimizer/sort_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/sort_pushdown.rs @@ -571,9 +571,7 @@ fn handle_custom_pushdown( .iter() .map(|req| { let child_schema = plan.children()[maintained_child_idx].schema(); - let updated_columns = req - .expr - .clone() + let updated_columns = Arc::clone(&req.expr) .transform_up(|expr| { if let Some(col) = expr.as_any().downcast_ref::() { let new_index = col.index() - sub_offset; diff --git a/datafusion/core/src/physical_optimizer/utils.rs b/datafusion/core/src/physical_optimizer/utils.rs index 9acd3f67c272..cdecc9d31862 100644 --- a/datafusion/core/src/physical_optimizer/utils.rs +++ b/datafusion/core/src/physical_optimizer/utils.rs @@ -46,7 +46,7 @@ pub fn add_sort_above( .equivalence_properties() .is_expr_constant(&sort_expr.expr) }); - let mut new_sort = SortExec::new(sort_expr, node.plan.clone()).with_fetch(fetch); + let mut new_sort = SortExec::new(sort_expr, Arc::clone(&node.plan)).with_fetch(fetch); if node.plan.output_partitioning().partition_count() > 1 { new_sort = new_sort.with_preserve_partitioning(true); } diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 2d3899adb00e..26f6b12908a7 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -327,7 +327,7 @@ impl DefaultPhysicalPlanner { // Spawning tasks which will traverse leaf up to the root. let tasks = flat_tree_leaf_indices .into_iter() - .map(|index| self.task_helper(index, flat_tree.clone(), session_state)); + .map(|index| self.task_helper(index, Arc::clone(&flat_tree), session_state)); let mut outputs = futures::stream::iter(tasks) .buffer_unordered(max_concurrency) .try_collect::>() @@ -486,7 +486,7 @@ impl DefaultPhysicalPlanner { output_schema, }) => { let output_schema: Schema = output_schema.as_ref().into(); - self.plan_describe(schema.clone(), Arc::new(output_schema))? 
+ self.plan_describe(Arc::clone(schema), Arc::new(output_schema))? } // 1 Child @@ -690,7 +690,7 @@ impl DefaultPhysicalPlanner { aggregates, filters.clone(), input_exec, - physical_input_schema.clone(), + Arc::clone(&physical_input_schema), )?); let can_repartition = !groups.is_empty() @@ -721,7 +721,7 @@ impl DefaultPhysicalPlanner { updated_aggregates, filters, initial_aggr, - physical_input_schema.clone(), + Arc::clone(&physical_input_schema), )?) } LogicalPlan::Projection(Projection { input, expr, .. }) => self @@ -893,8 +893,8 @@ impl DefaultPhysicalPlanner { let right = Arc::new(right); let new_join = LogicalPlan::Join(Join::try_new_with_project_input( node, - left.clone(), - right.clone(), + Arc::clone(&left), + Arc::clone(&right), column_on, )?); diff --git a/datafusion/core/src/test_util/mod.rs b/datafusion/core/src/test_util/mod.rs index e03c18fec7c4..c4c84d667a06 100644 --- a/datafusion/core/src/test_util/mod.rs +++ b/datafusion/core/src/test_util/mod.rs @@ -209,7 +209,7 @@ impl TableProvider for TestTableProvider { } fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } fn table_type(&self) -> TableType { @@ -425,7 +425,7 @@ impl TestAggregate { /// Create a new COUNT(column) aggregate pub fn new_count_column(schema: &Arc) -> Self { - Self::ColumnA(schema.clone()) + Self::ColumnA(Arc::clone(schema)) } /// Return appropriate expr depending if COUNT is for col or table (*) diff --git a/datafusion/core/src/test_util/parquet.rs b/datafusion/core/src/test_util/parquet.rs index 4e4a3747208e..fc6b4ad9d0b4 100644 --- a/datafusion/core/src/test_util/parquet.rs +++ b/datafusion/core/src/test_util/parquet.rs @@ -87,7 +87,8 @@ impl TestParquetFile { let first_batch = batches.next().expect("need at least one record batch"); let schema = first_batch.schema(); - let mut writer = ArrowWriter::try_new(file, schema.clone(), Some(props)).unwrap(); + let mut writer = + ArrowWriter::try_new(file, Arc::clone(&schema), Some(props)).unwrap(); writer.write(&first_batch).unwrap(); let mut num_rows = first_batch.num_rows(); @@ -144,7 +145,7 @@ impl TestParquetFile { maybe_filter: Option, ) -> Result> { let scan_config = - FileScanConfig::new(self.object_store_url.clone(), self.schema.clone()) + FileScanConfig::new(self.object_store_url.clone(), Arc::clone(&self.schema)) .with_file(PartitionedFile { object_meta: self.object_meta.clone(), partition_values: vec![], @@ -154,11 +155,11 @@ impl TestParquetFile { metadata_size_hint: None, }); - let df_schema = self.schema.clone().to_dfschema_ref()?; + let df_schema = Arc::clone(&self.schema).to_dfschema_ref()?; // run coercion on the filters to coerce types etc. 
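// [editor's aside, not part of the patch] `df_schema` above is a `DFSchemaRef`, an alias for
// `Arc<DFSchema>` (just as arrow's `SchemaRef` is `Arc<Schema>`), which is why these
// test-util hunks can hand it around with `Arc::clone` instead of copying the schema.
// Rough shape of such an accessor:
//
//     use std::sync::Arc;
//     struct Schema { fields: Vec<String> }
//     type SchemaRef = Arc<Schema>;
//     struct TestFile { schema: SchemaRef }
//     impl TestFile {
//         fn schema(&self) -> SchemaRef { Arc::clone(&self.schema) }
//     }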
let props = ExecutionProps::new(); - let context = SimplifyContext::new(&props).with_schema(df_schema.clone()); + let context = SimplifyContext::new(&props).with_schema(Arc::clone(&df_schema)); let parquet_options = ctx.copied_table_options().parquet; if let Some(filter) = maybe_filter { let simplifier = ExprSimplifier::new(context); @@ -168,7 +169,7 @@ impl TestParquetFile { let parquet_exec = ParquetExecBuilder::new_with_options(scan_config, parquet_options) - .with_predicate(physical_filter_expr.clone()) + .with_predicate(Arc::clone(&physical_filter_expr)) .build_arc(); let exec = Arc::new(FilterExec::try_new(physical_filter_expr, parquet_exec)?); @@ -200,7 +201,7 @@ impl TestParquetFile { /// The schema of this parquet file pub fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } /// The path to the parquet file From 8d6899ee6bf73ec9d93a165aa972c33a9c2504c9 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Tue, 12 Nov 2024 22:13:38 +0800 Subject: [PATCH 09/17] Support TypeSignature::Nullary (#13354) * support zero arg Signed-off-by: jayzhan211 * rename to nullary Signed-off-by: jayzhan211 * rename Signed-off-by: jayzhan211 * tostring Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- datafusion/expr-common/src/signature.rs | 23 ++++++++---- .../expr/src/type_coercion/functions.rs | 35 ++++++++++++++++--- datafusion/functions-aggregate/src/count.rs | 3 +- datafusion/functions-nested/src/make_array.rs | 2 +- datafusion/functions-window/src/cume_dist.rs | 2 +- datafusion/functions-window/src/rank.rs | 2 +- datafusion/functions-window/src/row_number.rs | 2 +- .../functions/src/datetime/current_date.rs | 2 +- .../functions/src/datetime/current_time.rs | 2 +- datafusion/functions/src/datetime/now.rs | 2 +- datafusion/functions/src/math/pi.rs | 2 +- datafusion/functions/src/math/random.rs | 2 +- 12 files changed, 58 insertions(+), 21 deletions(-) diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 3846fae5de5d..0fffd84b7047 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -113,8 +113,7 @@ pub enum TypeSignature { /// arguments like `vec![DataType::Int32]` or `vec![DataType::Float32]` /// since i32 and f32 can be casted to f64 Coercible(Vec), - /// Fixed number of arguments of arbitrary types - /// If a function takes 0 argument, its `TypeSignature` should be `Any(0)` + /// Fixed number of arguments of arbitrary types, number should be larger than 0 Any(usize), /// Matches exactly one of a list of [`TypeSignature`]s. Coercion is attempted to match /// the signatures in order, and stops after the first success, if any. @@ -135,6 +134,8 @@ pub enum TypeSignature { /// Null is considerd as `Utf8` by default /// Dictionary with string value type is also handled. 
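// [editor's aside, not part of the patch] The variant added just below gives zero-argument
// functions a dedicated signature instead of piggybacking on `Exact(vec![])`,
// `Uniform(0, ..)` or `Any(0)`; later hunks in this patch move `pi()`, `random()`, `now()`,
// `current_date()`, `rank()` and friends onto the matching constructor, e.g. for `pi()`:
//
//     Signature::nullary(Volatility::Immutable)   // previously Signature::exact(vec![], ..)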
String(usize), + /// Zero argument + NullAry, } #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] @@ -191,6 +192,9 @@ impl std::fmt::Display for ArrayFunctionSignature { impl TypeSignature { pub fn to_string_repr(&self) -> Vec { match self { + TypeSignature::NullAry => { + vec!["NullAry()".to_string()] + } TypeSignature::Variadic(types) => { vec![format!("{}, ..", Self::join_types(types, "/"))] } @@ -244,7 +248,7 @@ impl TypeSignature { pub fn supports_zero_argument(&self) -> bool { match &self { TypeSignature::Exact(vec) => vec.is_empty(), - TypeSignature::Uniform(0, _) | TypeSignature::Any(0) => true, + TypeSignature::NullAry => true, TypeSignature::OneOf(types) => types .iter() .any(|type_sig| type_sig.supports_zero_argument()), @@ -287,6 +291,7 @@ impl TypeSignature { .collect(), // TODO: Implement for other types TypeSignature::Any(_) + | TypeSignature::NullAry | TypeSignature::VariadicAny | TypeSignature::ArraySignature(_) | TypeSignature::UserDefined => vec![], @@ -407,6 +412,13 @@ impl Signature { } } + pub fn nullary(volatility: Volatility) -> Self { + Signature { + type_signature: TypeSignature::NullAry, + volatility, + } + } + /// A specified number of arguments of any type pub fn any(arg_count: usize, volatility: Volatility) -> Self { Signature { @@ -477,13 +489,12 @@ mod tests { // Testing `TypeSignature`s which supports 0 arg let positive_cases = vec![ TypeSignature::Exact(vec![]), - TypeSignature::Uniform(0, vec![DataType::Float64]), - TypeSignature::Any(0), TypeSignature::OneOf(vec![ TypeSignature::Exact(vec![DataType::Int8]), - TypeSignature::Any(0), + TypeSignature::NullAry, TypeSignature::Uniform(1, vec![DataType::Int8]), ]), + TypeSignature::NullAry, ]; for case in positive_cases { diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 5a4d89a0b2ec..6836713d8016 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -181,6 +181,7 @@ fn is_well_supported_signature(type_signature: &TypeSignature) -> bool { | TypeSignature::String(_) | TypeSignature::Coercible(_) | TypeSignature::Any(_) + | TypeSignature::NullAry ) } @@ -554,16 +555,27 @@ fn get_valid_types( vec![new_types] } - TypeSignature::Uniform(number, valid_types) => valid_types - .iter() - .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect()) - .collect(), + TypeSignature::Uniform(number, valid_types) => { + if *number == 0 { + return plan_err!("The function expected at least one argument"); + } + + valid_types + .iter() + .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect()) + .collect() + } TypeSignature::UserDefined => { return internal_err!( "User-defined signature should be handled by function-specific coerce_types." 
) } TypeSignature::VariadicAny => { + if current_types.is_empty() { + return plan_err!( + "The function expected at least one argument but received 0" + ); + } vec![current_types.to_vec()] } TypeSignature::Exact(valid_types) => vec![valid_types.clone()], @@ -606,7 +618,22 @@ fn get_valid_types( } } }, + TypeSignature::NullAry => { + if !current_types.is_empty() { + return plan_err!( + "The function expected zero argument but received {}", + current_types.len() + ); + } + vec![vec![]] + } TypeSignature::Any(number) => { + if current_types.is_empty() { + return plan_err!( + "The function expected at least one argument but received 0" + ); + } + if current_types.len() != *number { return plan_err!( "The function expected {} arguments but received {}", diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index bade589a908a..52181372698f 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -102,8 +102,7 @@ impl Count { pub fn new() -> Self { Self { signature: Signature::one_of( - // TypeSignature::Any(0) is required to handle `Count()` with no args - vec![TypeSignature::VariadicAny, TypeSignature::Any(0)], + vec![TypeSignature::VariadicAny, TypeSignature::NullAry], Volatility::Immutable, ), } diff --git a/datafusion/functions-nested/src/make_array.rs b/datafusion/functions-nested/src/make_array.rs index 7aa3445f6858..de67b0ae3874 100644 --- a/datafusion/functions-nested/src/make_array.rs +++ b/datafusion/functions-nested/src/make_array.rs @@ -63,7 +63,7 @@ impl MakeArray { pub fn new() -> Self { Self { signature: Signature::one_of( - vec![TypeSignature::UserDefined, TypeSignature::Any(0)], + vec![TypeSignature::NullAry, TypeSignature::UserDefined], Volatility::Immutable, ), aliases: vec![String::from("make_list")], diff --git a/datafusion/functions-window/src/cume_dist.rs b/datafusion/functions-window/src/cume_dist.rs index 9e30c672fee5..500d96b56323 100644 --- a/datafusion/functions-window/src/cume_dist.rs +++ b/datafusion/functions-window/src/cume_dist.rs @@ -49,7 +49,7 @@ pub struct CumeDist { impl CumeDist { pub fn new() -> Self { Self { - signature: Signature::any(0, Volatility::Immutable), + signature: Signature::nullary(Volatility::Immutable), } } } diff --git a/datafusion/functions-window/src/rank.rs b/datafusion/functions-window/src/rank.rs index 06c3f49055a5..06945e693eea 100644 --- a/datafusion/functions-window/src/rank.rs +++ b/datafusion/functions-window/src/rank.rs @@ -74,7 +74,7 @@ impl Rank { pub fn new(name: String, rank_type: RankType) -> Self { Self { name, - signature: Signature::any(0, Volatility::Immutable), + signature: Signature::nullary(Volatility::Immutable), rank_type, } } diff --git a/datafusion/functions-window/src/row_number.rs b/datafusion/functions-window/src/row_number.rs index 56af14fb84ae..68f6fde23280 100644 --- a/datafusion/functions-window/src/row_number.rs +++ b/datafusion/functions-window/src/row_number.rs @@ -51,7 +51,7 @@ impl RowNumber { /// Create a new `row_number` function pub fn new() -> Self { Self { - signature: Signature::any(0, Volatility::Immutable), + signature: Signature::nullary(Volatility::Immutable), } } } diff --git a/datafusion/functions/src/datetime/current_date.rs b/datafusion/functions/src/datetime/current_date.rs index 24046611a71f..3b819c470d1e 100644 --- a/datafusion/functions/src/datetime/current_date.rs +++ b/datafusion/functions/src/datetime/current_date.rs @@ -44,7 +44,7 @@ impl Default for CurrentDateFunc { impl 
CurrentDateFunc { pub fn new() -> Self { Self { - signature: Signature::uniform(0, vec![], Volatility::Stable), + signature: Signature::nullary(Volatility::Stable), aliases: vec![String::from("today")], } } diff --git a/datafusion/functions/src/datetime/current_time.rs b/datafusion/functions/src/datetime/current_time.rs index 4122b54b07e8..ca591f922305 100644 --- a/datafusion/functions/src/datetime/current_time.rs +++ b/datafusion/functions/src/datetime/current_time.rs @@ -42,7 +42,7 @@ impl Default for CurrentTimeFunc { impl CurrentTimeFunc { pub fn new() -> Self { Self { - signature: Signature::uniform(0, vec![], Volatility::Stable), + signature: Signature::nullary(Volatility::Stable), } } } diff --git a/datafusion/functions/src/datetime/now.rs b/datafusion/functions/src/datetime/now.rs index c13bbfb18105..cadc4fce04f1 100644 --- a/datafusion/functions/src/datetime/now.rs +++ b/datafusion/functions/src/datetime/now.rs @@ -43,7 +43,7 @@ impl Default for NowFunc { impl NowFunc { pub fn new() -> Self { Self { - signature: Signature::uniform(0, vec![], Volatility::Stable), + signature: Signature::nullary(Volatility::Stable), aliases: vec!["current_timestamp".to_string()], } } diff --git a/datafusion/functions/src/math/pi.rs b/datafusion/functions/src/math/pi.rs index 502429d0ca5d..70cc76f03c58 100644 --- a/datafusion/functions/src/math/pi.rs +++ b/datafusion/functions/src/math/pi.rs @@ -41,7 +41,7 @@ impl Default for PiFunc { impl PiFunc { pub fn new() -> Self { Self { - signature: Signature::exact(vec![], Volatility::Immutable), + signature: Signature::nullary(Volatility::Immutable), } } } diff --git a/datafusion/functions/src/math/random.rs b/datafusion/functions/src/math/random.rs index cd92798d67dd..0026037c95bd 100644 --- a/datafusion/functions/src/math/random.rs +++ b/datafusion/functions/src/math/random.rs @@ -42,7 +42,7 @@ impl Default for RandomFunc { impl RandomFunc { pub fn new() -> Self { Self { - signature: Signature::exact(vec![], Volatility::Volatile), + signature: Signature::nullary(Volatility::Volatile), } } } From 705dd0e209629d9d9202bc15ec9ae381a5521d4f Mon Sep 17 00:00:00 2001 From: Dima <111751109+Dimchikkk@users.noreply.github.com> Date: Tue, 12 Nov 2024 18:21:38 +0000 Subject: [PATCH 10/17] improve performance of regexp_count (#13364) * improve performance of regexp_count * fix clippy * collect with Int64Array to eliminate one temp Vec --------- Co-authored-by: Dima --- datafusion/functions/src/regex/regexpcount.rs | 82 +++++++++---------- 1 file changed, 39 insertions(+), 43 deletions(-) diff --git a/datafusion/functions/src/regex/regexpcount.rs b/datafusion/functions/src/regex/regexpcount.rs index 7c4313effffb..1286c6b5b1bc 100644 --- a/datafusion/functions/src/regex/regexpcount.rs +++ b/datafusion/functions/src/regex/regexpcount.rs @@ -30,7 +30,6 @@ use datafusion_expr::{ }; use itertools::izip; use regex::Regex; -use std::collections::hash_map::Entry; use std::collections::HashMap; use std::sync::{Arc, OnceLock}; @@ -312,12 +311,12 @@ where let pattern = compile_regex(regex, flags_scalar)?; - Ok(Arc::new(Int64Array::from_iter_values( + Ok(Arc::new( values .iter() .map(|value| count_matches(value, &pattern, start_scalar)) - .collect::, ArrowError>>()?, - ))) + .collect::>()?, + )) } (true, true, false) => { let regex = match regex_scalar { @@ -336,17 +335,17 @@ where ))); } - Ok(Arc::new(Int64Array::from_iter_values( + Ok(Arc::new( values .iter() .zip(flags_array.iter()) .map(|(value, flags)| { let pattern = compile_and_cache_regex(regex, flags, &mut regex_cache)?; 
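// [editor's aside, not part of the patch] Two changes thread through the closures in this
// file: results are collected straight into an `Int64Array` instead of a temporary `Vec` fed
// to `from_iter_values`, and `compile_and_cache_regex` (rewritten further down in this diff)
// now hands back a borrowed `&Regex` from the cache rather than an owned clone, which is why
// the call sites switch from `&pattern` to `pattern`.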
- count_matches(value, &pattern, start_scalar) + count_matches(value, pattern, start_scalar) }) - .collect::, ArrowError>>()?, - ))) + .collect::>()?, + )) } (true, false, true) => { let regex = match regex_scalar { @@ -360,13 +359,13 @@ where let start_array = start_array.unwrap(); - Ok(Arc::new(Int64Array::from_iter_values( + Ok(Arc::new( values .iter() .zip(start_array.iter()) .map(|(value, start)| count_matches(value, &pattern, start)) - .collect::, ArrowError>>()?, - ))) + .collect::>()?, + )) } (true, false, false) => { let regex = match regex_scalar { @@ -385,7 +384,7 @@ where ))); } - Ok(Arc::new(Int64Array::from_iter_values( + Ok(Arc::new( izip!( values.iter(), start_array.unwrap().iter(), @@ -395,10 +394,10 @@ where let pattern = compile_and_cache_regex(regex, flags, &mut regex_cache)?; - count_matches(value, &pattern, start) + count_matches(value, pattern, start) }) - .collect::, ArrowError>>()?, - ))) + .collect::>()?, + )) } (false, true, true) => { if values.len() != regex_array.len() { @@ -409,7 +408,7 @@ where ))); } - Ok(Arc::new(Int64Array::from_iter_values( + Ok(Arc::new( values .iter() .zip(regex_array.iter()) @@ -424,10 +423,10 @@ where flags_scalar, &mut regex_cache, )?; - count_matches(value, &pattern, start_scalar) + count_matches(value, pattern, start_scalar) }) - .collect::, ArrowError>>()?, - ))) + .collect::>()?, + )) } (false, true, false) => { if values.len() != regex_array.len() { @@ -447,7 +446,7 @@ where ))); } - Ok(Arc::new(Int64Array::from_iter_values( + Ok(Arc::new( izip!(values.iter(), regex_array.iter(), flags_array.iter()) .map(|(value, regex, flags)| { let regex = match regex { @@ -458,10 +457,10 @@ where let pattern = compile_and_cache_regex(regex, flags, &mut regex_cache)?; - count_matches(value, &pattern, start_scalar) + count_matches(value, pattern, start_scalar) }) - .collect::, ArrowError>>()?, - ))) + .collect::>()?, + )) } (false, false, true) => { if values.len() != regex_array.len() { @@ -481,7 +480,7 @@ where ))); } - Ok(Arc::new(Int64Array::from_iter_values( + Ok(Arc::new( izip!(values.iter(), regex_array.iter(), start_array.iter()) .map(|(value, regex, start)| { let regex = match regex { @@ -494,10 +493,10 @@ where flags_scalar, &mut regex_cache, )?; - count_matches(value, &pattern, start) + count_matches(value, pattern, start) }) - .collect::, ArrowError>>()?, - ))) + .collect::>()?, + )) } (false, false, false) => { if values.len() != regex_array.len() { @@ -526,7 +525,7 @@ where ))); } - Ok(Arc::new(Int64Array::from_iter_values( + Ok(Arc::new( izip!( values.iter(), regex_array.iter(), @@ -541,27 +540,24 @@ where let pattern = compile_and_cache_regex(regex, flags, &mut regex_cache)?; - count_matches(value, &pattern, start) + count_matches(value, pattern, start) }) - .collect::, ArrowError>>()?, - ))) + .collect::>()?, + )) } } } -fn compile_and_cache_regex( - regex: &str, - flags: Option<&str>, - regex_cache: &mut HashMap, -) -> Result { - match regex_cache.entry(regex.to_string()) { - Entry::Vacant(entry) => { - let compiled = compile_regex(regex, flags)?; - entry.insert(compiled.clone()); - Ok(compiled) - } - Entry::Occupied(entry) => Ok(entry.get().to_owned()), +fn compile_and_cache_regex<'a>( + regex: &'a str, + flags: Option<&'a str>, + regex_cache: &'a mut HashMap, +) -> Result<&'a Regex, ArrowError> { + if !regex_cache.contains_key(regex) { + let compiled = compile_regex(regex, flags)?; + regex_cache.insert(regex.to_string(), compiled); } + Ok(regex_cache.get(regex).unwrap()) } fn compile_regex(regex: &str, flags: Option<&str>) 
-> Result { From 2a2de82f6789d2f6eee669c762c4a13704a52b12 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Wed, 13 Nov 2024 01:55:28 +0100 Subject: [PATCH 11/17] annotate get_type with recursive (#13376) --- datafusion/expr/src/expr_schema.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index f0a6ed89e6e9..2225f457f626 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -32,6 +32,7 @@ use datafusion_common::{ TableReference, }; use datafusion_functions_window_common::field::WindowUDFFieldArgs; +use recursive::recursive; use std::collections::HashMap; use std::sync::Arc; @@ -99,6 +100,7 @@ impl ExprSchemable for Expr { /// expression refers to a column that does not exist in the /// schema, or when the expression is incorrectly typed /// (e.g. `[utf8] + [bool]`). + #[recursive] fn get_type(&self, schema: &dyn ExprSchema) -> Result { match self { Expr::Alias(Alias { expr, name, .. }) => match &**expr { From 430a67cdb6664d4f98a7e3b057cdb3024dc44450 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Tue, 12 Nov 2024 20:03:56 -0500 Subject: [PATCH 12/17] Directly support utf8view in nullif. #13379 (#13380) --- datafusion/functions/src/core/nullif.rs | 1 + datafusion/sqllogictest/test_files/nullif.slt | 30 +++++++++++++++++++ .../test_files/string/string_view.slt | 20 +++++++++++++ 3 files changed, 51 insertions(+) diff --git a/datafusion/functions/src/core/nullif.rs b/datafusion/functions/src/core/nullif.rs index f96ee1ea7a12..801a80201946 100644 --- a/datafusion/functions/src/core/nullif.rs +++ b/datafusion/functions/src/core/nullif.rs @@ -47,6 +47,7 @@ static SUPPORTED_NULLIF_TYPES: &[DataType] = &[ DataType::Int64, DataType::Float32, DataType::Float64, + DataType::Utf8View, DataType::Utf8, DataType::LargeUtf8, ]; diff --git a/datafusion/sqllogictest/test_files/nullif.slt b/datafusion/sqllogictest/test_files/nullif.slt index f8240f70e363..a5060077fe77 100644 --- a/datafusion/sqllogictest/test_files/nullif.slt +++ b/datafusion/sqllogictest/test_files/nullif.slt @@ -101,3 +101,33 @@ query I SELECT NULLIF(NULL, NULL); ---- NULL + +query T +SELECT NULLIF(arrow_cast('a', 'Utf8View'), 'a'); +---- +NULL + +query T +SELECT NULLIF('a', arrow_cast('a', 'Utf8View')); +---- +NULL + +query T +SELECT NULLIF(arrow_cast('a', 'Utf8View'), 'b'); +---- +a + +query T +SELECT NULLIF('a', arrow_cast('b', 'Utf8View')); +---- +a + +query T +SELECT NULLIF(null, arrow_cast('a', 'Utf8View')); +---- +NULL + +query T +SELECT NULLIF(arrow_cast('a', 'Utf8View'), null); +---- +a \ No newline at end of file diff --git a/datafusion/sqllogictest/test_files/string/string_view.slt b/datafusion/sqllogictest/test_files/string/string_view.slt index ce8a295373aa..2f4af80a9257 100644 --- a/datafusion/sqllogictest/test_files/string/string_view.slt +++ b/datafusion/sqllogictest/test_files/string/string_view.slt @@ -934,6 +934,26 @@ logical_plan 01)Projection: to_timestamp(test.column1_utf8view, Utf8("a,b,c,d")) AS c 02)--TableScan: test projection=[column1_utf8view] +## Ensure no casts for nullif +query TT +EXPLAIN SELECT + nullif(column1_utf8view, 'a') as c +FROM test; +---- +logical_plan +01)Projection: nullif(test.column1_utf8view, Utf8View("a")) AS c +02)--TableScan: test projection=[column1_utf8view] + +## Ensure no casts for nullif +query TT +EXPLAIN SELECT + nullif(column1_utf8view, column1_utf8view) as c +FROM test; +---- +logical_plan +01)Projection: nullif(test.column1_utf8view, test.column1_utf8view) AS c 
+02)--TableScan: test projection=[column1_utf8view] + ## Ensure no casts for binary operators # `~` operator (regex match) query TT From f894c7deb7ed04ec6d2c75c10326d1f6811fa2f4 Mon Sep 17 00:00:00 2001 From: Jiashen Cao Date: Tue, 12 Nov 2024 22:15:50 -0800 Subject: [PATCH 13/17] update (#13352) --- .../aggregates/group_values/group_column.rs | 2034 ----------------- .../src/aggregates/group_values/mod.rs | 7 +- .../group_values/multi_column/bytes.rs | 633 +++++ .../group_values/multi_column/bytes_view.rs | 911 ++++++++ .../{column.rs => multi_column/mod.rs} | 78 +- .../group_values/multi_column/primitive.rs | 472 ++++ 6 files changed, 2093 insertions(+), 2042 deletions(-) delete mode 100644 datafusion/physical-plan/src/aggregates/group_values/group_column.rs create mode 100644 datafusion/physical-plan/src/aggregates/group_values/multi_column/bytes.rs create mode 100644 datafusion/physical-plan/src/aggregates/group_values/multi_column/bytes_view.rs rename datafusion/physical-plan/src/aggregates/group_values/{column.rs => multi_column/mod.rs} (95%) create mode 100644 datafusion/physical-plan/src/aggregates/group_values/multi_column/primitive.rs diff --git a/datafusion/physical-plan/src/aggregates/group_values/group_column.rs b/datafusion/physical-plan/src/aggregates/group_values/group_column.rs deleted file mode 100644 index 1f59c617d883..000000000000 --- a/datafusion/physical-plan/src/aggregates/group_values/group_column.rs +++ /dev/null @@ -1,2034 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
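// [editor's aside, not part of the patch] Everything below is the deletion of
// group_column.rs. Per the diffstat above, its builders are presumably redistributed into
// multi_column/{bytes,bytes_view,primitive}.rs, with the former column.rs becoming
// multi_column/mod.rs (a ~95% rename).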
- -use arrow::array::make_view; -use arrow::array::BufferBuilder; -use arrow::array::ByteView; -use arrow::array::GenericBinaryArray; -use arrow::array::GenericStringArray; -use arrow::array::OffsetSizeTrait; -use arrow::array::PrimitiveArray; -use arrow::array::{Array, ArrayRef, ArrowPrimitiveType, AsArray}; -use arrow::buffer::OffsetBuffer; -use arrow::buffer::ScalarBuffer; -use arrow::datatypes::ByteArrayType; -use arrow::datatypes::ByteViewType; -use arrow::datatypes::DataType; -use arrow::datatypes::GenericBinaryType; -use arrow_array::GenericByteArray; -use arrow_array::GenericByteViewArray; -use arrow_buffer::Buffer; -use datafusion_common::utils::proxy::VecAllocExt; -use itertools::izip; - -use crate::aggregates::group_values::null_builder::MaybeNullBufferBuilder; -use arrow_array::types::GenericStringType; -use datafusion_physical_expr_common::binary_map::{OutputType, INITIAL_BUFFER_CAPACITY}; -use std::iter; -use std::marker::PhantomData; -use std::mem::{replace, size_of}; -use std::sync::Arc; -use std::vec; - -const BYTE_VIEW_MAX_BLOCK_SIZE: usize = 2 * 1024 * 1024; - -/// Trait for storing a single column of group values in [`GroupValuesColumn`] -/// -/// Implementations of this trait store an in-progress collection of group values -/// (similar to various builders in Arrow-rs) that allow for quick comparison to -/// incoming rows. -/// -/// [`GroupValuesColumn`]: crate::aggregates::group_values::GroupValuesColumn -pub trait GroupColumn: Send + Sync { - /// Returns equal if the row stored in this builder at `lhs_row` is equal to - /// the row in `array` at `rhs_row` - /// - /// Note that this comparison returns true if both elements are NULL - fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool; - - /// Appends the row at `row` in `array` to this builder - fn append_val(&mut self, array: &ArrayRef, row: usize); - - /// The vectorized version equal to - /// - /// When found nth row stored in this builder at `lhs_row` - /// is equal to the row in `array` at `rhs_row`, - /// it will record the `true` result at the corresponding - /// position in `equal_to_results`. - /// - /// And if found nth result in `equal_to_results` is already - /// `false`, the check for nth row will be skipped. 
- /// - fn vectorized_equal_to( - &self, - lhs_rows: &[usize], - array: &ArrayRef, - rhs_rows: &[usize], - equal_to_results: &mut [bool], - ); - - /// The vectorized version `append_val` - fn vectorized_append(&mut self, array: &ArrayRef, rows: &[usize]); - - /// Returns the number of rows stored in this builder - fn len(&self) -> usize; - - /// Returns the number of bytes used by this [`GroupColumn`] - fn size(&self) -> usize; - - /// Builds a new array from all of the stored rows - fn build(self: Box) -> ArrayRef; - - /// Builds a new array from the first `n` stored rows, shifting the - /// remaining rows to the start of the builder - fn take_n(&mut self, n: usize) -> ArrayRef; -} - -/// An implementation of [`GroupColumn`] for primitive values -/// -/// Optimized to skip null buffer construction if the input is known to be non nullable -/// -/// # Template parameters -/// -/// `T`: the native Rust type that stores the data -/// `NULLABLE`: if the data can contain any nulls -#[derive(Debug)] -pub struct PrimitiveGroupValueBuilder { - group_values: Vec, - nulls: MaybeNullBufferBuilder, -} - -impl PrimitiveGroupValueBuilder -where - T: ArrowPrimitiveType, -{ - /// Create a new `PrimitiveGroupValueBuilder` - pub fn new() -> Self { - Self { - group_values: vec![], - nulls: MaybeNullBufferBuilder::new(), - } - } -} - -impl GroupColumn - for PrimitiveGroupValueBuilder -{ - fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool { - // Perf: skip null check (by short circuit) if input is not nullable - if NULLABLE { - let exist_null = self.nulls.is_null(lhs_row); - let input_null = array.is_null(rhs_row); - if let Some(result) = nulls_equal_to(exist_null, input_null) { - return result; - } - // Otherwise, we need to check their values - } - - self.group_values[lhs_row] == array.as_primitive::().value(rhs_row) - } - - fn append_val(&mut self, array: &ArrayRef, row: usize) { - // Perf: skip null check if input can't have nulls - if NULLABLE { - if array.is_null(row) { - self.nulls.append(true); - self.group_values.push(T::default_value()); - } else { - self.nulls.append(false); - self.group_values.push(array.as_primitive::().value(row)); - } - } else { - self.group_values.push(array.as_primitive::().value(row)); - } - } - - fn vectorized_equal_to( - &self, - lhs_rows: &[usize], - array: &ArrayRef, - rhs_rows: &[usize], - equal_to_results: &mut [bool], - ) { - let array = array.as_primitive::(); - - let iter = izip!( - lhs_rows.iter(), - rhs_rows.iter(), - equal_to_results.iter_mut(), - ); - - for (&lhs_row, &rhs_row, equal_to_result) in iter { - // Has found not equal to in previous column, don't need to check - if !*equal_to_result { - continue; - } - - // Perf: skip null check (by short circuit) if input is not nullable - if NULLABLE { - let exist_null = self.nulls.is_null(lhs_row); - let input_null = array.is_null(rhs_row); - if let Some(result) = nulls_equal_to(exist_null, input_null) { - *equal_to_result = result; - continue; - } - // Otherwise, we need to check their values - } - - *equal_to_result = self.group_values[lhs_row] == array.value(rhs_row); - } - } - - fn vectorized_append(&mut self, array: &ArrayRef, rows: &[usize]) { - let arr = array.as_primitive::(); - - let null_count = array.null_count(); - let num_rows = array.len(); - let all_null_or_non_null = if null_count == 0 { - Some(true) - } else if null_count == num_rows { - Some(false) - } else { - None - }; - - match (NULLABLE, all_null_or_non_null) { - (true, None) => { - for &row in rows { - if 
array.is_null(row) { - self.nulls.append(true); - self.group_values.push(T::default_value()); - } else { - self.nulls.append(false); - self.group_values.push(arr.value(row)); - } - } - } - - (true, Some(true)) => { - self.nulls.append_n(rows.len(), false); - for &row in rows { - self.group_values.push(arr.value(row)); - } - } - - (true, Some(false)) => { - self.nulls.append_n(rows.len(), true); - self.group_values - .extend(iter::repeat(T::default_value()).take(rows.len())); - } - - (false, _) => { - for &row in rows { - self.group_values.push(arr.value(row)); - } - } - } - } - - fn len(&self) -> usize { - self.group_values.len() - } - - fn size(&self) -> usize { - self.group_values.allocated_size() + self.nulls.allocated_size() - } - - fn build(self: Box) -> ArrayRef { - let Self { - group_values, - nulls, - } = *self; - - let nulls = nulls.build(); - if !NULLABLE { - assert!(nulls.is_none(), "unexpected nulls in non nullable input"); - } - - Arc::new(PrimitiveArray::::new( - ScalarBuffer::from(group_values), - nulls, - )) - } - - fn take_n(&mut self, n: usize) -> ArrayRef { - let first_n = self.group_values.drain(0..n).collect::>(); - - let first_n_nulls = if NULLABLE { self.nulls.take_n(n) } else { None }; - - Arc::new(PrimitiveArray::::new( - ScalarBuffer::from(first_n), - first_n_nulls, - )) - } -} - -/// An implementation of [`GroupColumn`] for binary and utf8 types. -/// -/// Stores a collection of binary or utf8 group values in a single buffer -/// in a way that allows: -/// -/// 1. Efficient comparison of incoming rows to existing rows -/// 2. Efficient construction of the final output array -pub struct ByteGroupValueBuilder -where - O: OffsetSizeTrait, -{ - output_type: OutputType, - buffer: BufferBuilder, - /// Offsets into `buffer` for each distinct value. These offsets as used - /// directly to create the final `GenericBinaryArray`. The `i`th string is - /// stored in the range `offsets[i]..offsets[i+1]` in `buffer`. Null values - /// are stored as a zero length string. 
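// [editor's aside, not part of the patch] This is the usual Arrow variable-length layout:
// value `i` occupies `buffer[offsets[i]..offsets[i + 1]]`, so appending "ab", then a null,
// then "cde" leaves offsets as [0, 2, 2, 5] with the null contributing a zero-length range.
// For instance:
//
//     let (buffer, offsets) = (b"abcde", [0usize, 2, 2, 5]);
//     assert_eq!(&buffer[offsets[2]..offsets[3]], b"cde");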
- offsets: Vec, - /// Nulls - nulls: MaybeNullBufferBuilder, -} - -impl ByteGroupValueBuilder -where - O: OffsetSizeTrait, -{ - pub fn new(output_type: OutputType) -> Self { - Self { - output_type, - buffer: BufferBuilder::new(INITIAL_BUFFER_CAPACITY), - offsets: vec![O::default()], - nulls: MaybeNullBufferBuilder::new(), - } - } - - fn equal_to_inner(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool - where - B: ByteArrayType, - { - let array = array.as_bytes::(); - self.do_equal_to_inner(lhs_row, array, rhs_row) - } - - fn append_val_inner(&mut self, array: &ArrayRef, row: usize) - where - B: ByteArrayType, - { - let arr = array.as_bytes::(); - if arr.is_null(row) { - self.nulls.append(true); - // nulls need a zero length in the offset buffer - let offset = self.buffer.len(); - self.offsets.push(O::usize_as(offset)); - } else { - self.nulls.append(false); - self.do_append_val_inner(arr, row); - } - } - - fn vectorized_equal_to_inner( - &self, - lhs_rows: &[usize], - array: &ArrayRef, - rhs_rows: &[usize], - equal_to_results: &mut [bool], - ) where - B: ByteArrayType, - { - let array = array.as_bytes::(); - - let iter = izip!( - lhs_rows.iter(), - rhs_rows.iter(), - equal_to_results.iter_mut(), - ); - - for (&lhs_row, &rhs_row, equal_to_result) in iter { - // Has found not equal to, don't need to check - if !*equal_to_result { - continue; - } - - *equal_to_result = self.do_equal_to_inner(lhs_row, array, rhs_row); - } - } - - fn vectorized_append_inner(&mut self, array: &ArrayRef, rows: &[usize]) - where - B: ByteArrayType, - { - let arr = array.as_bytes::(); - let null_count = array.null_count(); - let num_rows = array.len(); - let all_null_or_non_null = if null_count == 0 { - Some(true) - } else if null_count == num_rows { - Some(false) - } else { - None - }; - - match all_null_or_non_null { - None => { - for &row in rows { - if arr.is_null(row) { - self.nulls.append(true); - // nulls need a zero length in the offset buffer - let offset = self.buffer.len(); - self.offsets.push(O::usize_as(offset)); - } else { - self.nulls.append(false); - self.do_append_val_inner(arr, row); - } - } - } - - Some(true) => { - self.nulls.append_n(rows.len(), false); - for &row in rows { - self.do_append_val_inner(arr, row); - } - } - - Some(false) => { - self.nulls.append_n(rows.len(), true); - - let new_len = self.offsets.len() + rows.len(); - let offset = self.buffer.len(); - self.offsets.resize(new_len, O::usize_as(offset)); - } - } - } - - fn do_equal_to_inner( - &self, - lhs_row: usize, - array: &GenericByteArray, - rhs_row: usize, - ) -> bool - where - B: ByteArrayType, - { - let exist_null = self.nulls.is_null(lhs_row); - let input_null = array.is_null(rhs_row); - if let Some(result) = nulls_equal_to(exist_null, input_null) { - return result; - } - // Otherwise, we need to check their values - self.value(lhs_row) == (array.value(rhs_row).as_ref() as &[u8]) - } - - fn do_append_val_inner(&mut self, array: &GenericByteArray, row: usize) - where - B: ByteArrayType, - { - let value: &[u8] = array.value(row).as_ref(); - self.buffer.append_slice(value); - self.offsets.push(O::usize_as(self.buffer.len())); - } - - /// return the current value of the specified row irrespective of null - pub fn value(&self, row: usize) -> &[u8] { - let l = self.offsets[row].as_usize(); - let r = self.offsets[row + 1].as_usize(); - // Safety: the offsets are constructed correctly and never decrease - unsafe { self.buffer.as_slice().get_unchecked(l..r) } - } -} - -impl GroupColumn for ByteGroupValueBuilder -where 
- O: OffsetSizeTrait, -{ - fn equal_to(&self, lhs_row: usize, column: &ArrayRef, rhs_row: usize) -> bool { - // Sanity array type - match self.output_type { - OutputType::Binary => { - debug_assert!(matches!( - column.data_type(), - DataType::Binary | DataType::LargeBinary - )); - self.equal_to_inner::>(lhs_row, column, rhs_row) - } - OutputType::Utf8 => { - debug_assert!(matches!( - column.data_type(), - DataType::Utf8 | DataType::LargeUtf8 - )); - self.equal_to_inner::>(lhs_row, column, rhs_row) - } - _ => unreachable!("View types should use `ArrowBytesViewMap`"), - } - } - - fn append_val(&mut self, column: &ArrayRef, row: usize) { - // Sanity array type - match self.output_type { - OutputType::Binary => { - debug_assert!(matches!( - column.data_type(), - DataType::Binary | DataType::LargeBinary - )); - self.append_val_inner::>(column, row) - } - OutputType::Utf8 => { - debug_assert!(matches!( - column.data_type(), - DataType::Utf8 | DataType::LargeUtf8 - )); - self.append_val_inner::>(column, row) - } - _ => unreachable!("View types should use `ArrowBytesViewMap`"), - }; - } - - fn vectorized_equal_to( - &self, - lhs_rows: &[usize], - array: &ArrayRef, - rhs_rows: &[usize], - equal_to_results: &mut [bool], - ) { - // Sanity array type - match self.output_type { - OutputType::Binary => { - debug_assert!(matches!( - array.data_type(), - DataType::Binary | DataType::LargeBinary - )); - self.vectorized_equal_to_inner::>( - lhs_rows, - array, - rhs_rows, - equal_to_results, - ); - } - OutputType::Utf8 => { - debug_assert!(matches!( - array.data_type(), - DataType::Utf8 | DataType::LargeUtf8 - )); - self.vectorized_equal_to_inner::>( - lhs_rows, - array, - rhs_rows, - equal_to_results, - ); - } - _ => unreachable!("View types should use `ArrowBytesViewMap`"), - } - } - - fn vectorized_append(&mut self, column: &ArrayRef, rows: &[usize]) { - match self.output_type { - OutputType::Binary => { - debug_assert!(matches!( - column.data_type(), - DataType::Binary | DataType::LargeBinary - )); - self.vectorized_append_inner::>(column, rows) - } - OutputType::Utf8 => { - debug_assert!(matches!( - column.data_type(), - DataType::Utf8 | DataType::LargeUtf8 - )); - self.vectorized_append_inner::>(column, rows) - } - _ => unreachable!("View types should use `ArrowBytesViewMap`"), - }; - } - - fn len(&self) -> usize { - self.offsets.len() - 1 - } - - fn size(&self) -> usize { - self.buffer.capacity() * size_of::() - + self.offsets.allocated_size() - + self.nulls.allocated_size() - } - - fn build(self: Box) -> ArrayRef { - let Self { - output_type, - mut buffer, - offsets, - nulls, - } = *self; - - let null_buffer = nulls.build(); - - // SAFETY: the offsets were constructed correctly in `insert_if_new` -- - // monotonically increasing, overflows were checked. - let offsets = unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(offsets)) }; - let values = buffer.finish(); - match output_type { - OutputType::Binary => { - // SAFETY: the offsets were constructed correctly - Arc::new(unsafe { - GenericBinaryArray::new_unchecked(offsets, values, null_buffer) - }) - } - OutputType::Utf8 => { - // SAFETY: - // 1. the offsets were constructed safely - // - // 2. the input arrays were all the correct type and thus since - // all the values that went in were valid (e.g. 
utf8) so are all - // the values that come out - Arc::new(unsafe { - GenericStringArray::new_unchecked(offsets, values, null_buffer) - }) - } - _ => unreachable!("View types should use `ArrowBytesViewMap`"), - } - } - - fn take_n(&mut self, n: usize) -> ArrayRef { - debug_assert!(self.len() >= n); - let null_buffer = self.nulls.take_n(n); - let first_remaining_offset = O::as_usize(self.offsets[n]); - - // Given offests like [0, 2, 4, 5] and n = 1, we expect to get - // offsets [0, 2, 3]. We first create two offsets for first_n as [0, 2] and the remaining as [2, 4, 5]. - // And we shift the offset starting from 0 for the remaining one, [2, 4, 5] -> [0, 2, 3]. - let mut first_n_offsets = self.offsets.drain(0..n).collect::>(); - let offset_n = *self.offsets.first().unwrap(); - self.offsets - .iter_mut() - .for_each(|offset| *offset = offset.sub(offset_n)); - first_n_offsets.push(offset_n); - - // SAFETY: the offsets were constructed correctly in `insert_if_new` -- - // monotonically increasing, overflows were checked. - let offsets = - unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(first_n_offsets)) }; - - let mut remaining_buffer = - BufferBuilder::new(self.buffer.len() - first_remaining_offset); - // TODO: Current approach copy the remaining and truncate the original one - // Find out a way to avoid copying buffer but split the original one into two. - remaining_buffer.append_slice(&self.buffer.as_slice()[first_remaining_offset..]); - self.buffer.truncate(first_remaining_offset); - let values = self.buffer.finish(); - self.buffer = remaining_buffer; - - match self.output_type { - OutputType::Binary => { - // SAFETY: the offsets were constructed correctly - Arc::new(unsafe { - GenericBinaryArray::new_unchecked(offsets, values, null_buffer) - }) - } - OutputType::Utf8 => { - // SAFETY: - // 1. the offsets were constructed safely - // - // 2. we asserted the input arrays were all the correct type and - // thus since all the values that went in were valid (e.g. utf8) - // so are all the values that come out - Arc::new(unsafe { - GenericStringArray::new_unchecked(offsets, values, null_buffer) - }) - } - _ => unreachable!("View types should use `ArrowBytesViewMap`"), - } - } -} - -/// An implementation of [`GroupColumn`] for binary view and utf8 view types. -/// -/// Stores a collection of binary view or utf8 view group values in a buffer -/// whose structure is similar to `GenericByteViewArray`, and we can get benefits: -/// -/// 1. Efficient comparison of incoming rows to existing rows -/// 2. Efficient construction of the final output array -/// 3. Efficient to perform `take_n` comparing to use `GenericByteViewBuilder` -pub struct ByteViewGroupValueBuilder { - /// The views of string values - /// - /// If string len <= 12, the view's format will be: - /// string(12B) | len(4B) - /// - /// If string len > 12, its format will be: - /// offset(4B) | buffer_index(4B) | prefix(4B) | len(4B) - views: Vec, - - /// The progressing block - /// - /// New values will be inserted into it until its capacity - /// is not enough(detail can see `max_block_size`). - in_progress: Vec, - - /// The completed blocks - completed: Vec, - - /// The max size of `in_progress` - /// - /// `in_progress` will be flushed into `completed`, and create new `in_progress` - /// when found its remaining capacity(`max_block_size` - `len(in_progress)`), - /// is no enough to store the appended value. - /// - /// Currently it is fixed at 2MB. 
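// [editor's aside, not part of the patch] "2MB" is BYTE_VIEW_MAX_BLOCK_SIZE
// (2 * 1024 * 1024) from the top of this removed file. Values longer than 12 bytes are
// appended to `in_progress`; when one would overflow the current block, the block is frozen
// into `completed` first, roughly:
//
//     if in_progress.len() + value_len > max_block_size {
//         completed.push(Buffer::from_vec(std::mem::take(&mut in_progress)));
//     }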
- max_block_size: usize, - - /// Nulls - nulls: MaybeNullBufferBuilder, - - /// phantom data so the type requires `` - _phantom: PhantomData, -} - -impl ByteViewGroupValueBuilder { - pub fn new() -> Self { - Self { - views: Vec::new(), - in_progress: Vec::new(), - completed: Vec::new(), - max_block_size: BYTE_VIEW_MAX_BLOCK_SIZE, - nulls: MaybeNullBufferBuilder::new(), - _phantom: PhantomData {}, - } - } - - /// Set the max block size - fn with_max_block_size(mut self, max_block_size: usize) -> Self { - self.max_block_size = max_block_size; - self - } - - fn equal_to_inner(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool { - let array = array.as_byte_view::(); - self.do_equal_to_inner(lhs_row, array, rhs_row) - } - - fn append_val_inner(&mut self, array: &ArrayRef, row: usize) { - let arr = array.as_byte_view::(); - - // Null row case, set and return - if arr.is_null(row) { - self.nulls.append(true); - self.views.push(0); - return; - } - - // Not null row case - self.nulls.append(false); - self.do_append_val_inner(arr, row); - } - - fn vectorized_equal_to_inner( - &self, - lhs_rows: &[usize], - array: &ArrayRef, - rhs_rows: &[usize], - equal_to_results: &mut [bool], - ) { - let array = array.as_byte_view::(); - - let iter = izip!( - lhs_rows.iter(), - rhs_rows.iter(), - equal_to_results.iter_mut(), - ); - - for (&lhs_row, &rhs_row, equal_to_result) in iter { - // Has found not equal to, don't need to check - if !*equal_to_result { - continue; - } - - *equal_to_result = self.do_equal_to_inner(lhs_row, array, rhs_row); - } - } - - fn vectorized_append_inner(&mut self, array: &ArrayRef, rows: &[usize]) { - let arr = array.as_byte_view::(); - let null_count = array.null_count(); - let num_rows = array.len(); - let all_null_or_non_null = if null_count == 0 { - Some(true) - } else if null_count == num_rows { - Some(false) - } else { - None - }; - - match all_null_or_non_null { - None => { - for &row in rows { - // Null row case, set and return - if arr.is_valid(row) { - self.nulls.append(false); - self.do_append_val_inner(arr, row); - } else { - self.nulls.append(true); - self.views.push(0); - } - } - } - - Some(true) => { - self.nulls.append_n(rows.len(), false); - for &row in rows { - self.do_append_val_inner(arr, row); - } - } - - Some(false) => { - self.nulls.append_n(rows.len(), true); - let new_len = self.views.len() + rows.len(); - self.views.resize(new_len, 0); - } - } - } - - fn do_append_val_inner(&mut self, array: &GenericByteViewArray, row: usize) - where - B: ByteViewType, - { - let value: &[u8] = array.value(row).as_ref(); - - let value_len = value.len(); - let view = if value_len <= 12 { - make_view(value, 0, 0) - } else { - // Ensure big enough block to hold the value firstly - self.ensure_in_progress_big_enough(value_len); - - // Append value - let buffer_index = self.completed.len(); - let offset = self.in_progress.len(); - self.in_progress.extend_from_slice(value); - - make_view(value, buffer_index as u32, offset as u32) - }; - - // Append view - self.views.push(view); - } - - fn ensure_in_progress_big_enough(&mut self, value_len: usize) { - debug_assert!(value_len > 12); - let require_cap = self.in_progress.len() + value_len; - - // If current block isn't big enough, flush it and create a new in progress block - if require_cap > self.max_block_size { - let flushed_block = replace( - &mut self.in_progress, - Vec::with_capacity(self.max_block_size), - ); - let buffer = Buffer::from_vec(flushed_block); - self.completed.push(buffer); - } - } - - fn 
do_equal_to_inner( - &self, - lhs_row: usize, - array: &GenericByteViewArray, - rhs_row: usize, - ) -> bool { - // Check if nulls equal firstly - let exist_null = self.nulls.is_null(lhs_row); - let input_null = array.is_null(rhs_row); - if let Some(result) = nulls_equal_to(exist_null, input_null) { - return result; - } - - // Otherwise, we need to check their values - let exist_view = self.views[lhs_row]; - let exist_view_len = exist_view as u32; - - let input_view = array.views()[rhs_row]; - let input_view_len = input_view as u32; - - // The check logic - // - Check len equality - // - If inlined, check inlined value - // - If non-inlined, check prefix and then check value in buffer - // when needed - if exist_view_len != input_view_len { - return false; - } - - if exist_view_len <= 12 { - let exist_inline = unsafe { - GenericByteViewArray::::inline_value( - &exist_view, - exist_view_len as usize, - ) - }; - let input_inline = unsafe { - GenericByteViewArray::::inline_value( - &input_view, - input_view_len as usize, - ) - }; - exist_inline == input_inline - } else { - let exist_prefix = - unsafe { GenericByteViewArray::::inline_value(&exist_view, 4) }; - let input_prefix = - unsafe { GenericByteViewArray::::inline_value(&input_view, 4) }; - - if exist_prefix != input_prefix { - return false; - } - - let exist_full = { - let byte_view = ByteView::from(exist_view); - self.value( - byte_view.buffer_index as usize, - byte_view.offset as usize, - byte_view.length as usize, - ) - }; - let input_full: &[u8] = unsafe { array.value_unchecked(rhs_row).as_ref() }; - exist_full == input_full - } - } - - fn value(&self, buffer_index: usize, offset: usize, length: usize) -> &[u8] { - debug_assert!(buffer_index <= self.completed.len()); - - if buffer_index < self.completed.len() { - let block = &self.completed[buffer_index]; - &block[offset..offset + length] - } else { - &self.in_progress[offset..offset + length] - } - } - - fn build_inner(self) -> ArrayRef { - let Self { - views, - in_progress, - mut completed, - nulls, - .. - } = self; - - // Build nulls - let null_buffer = nulls.build(); - - // Build values - // Flush `in_process` firstly - if !in_progress.is_empty() { - let buffer = Buffer::from(in_progress); - completed.push(buffer); - } - - let views = ScalarBuffer::from(views); - - // Safety: - // * all views were correctly made - // * (if utf8): Input was valid Utf8 so buffer contents are - // valid utf8 as well - unsafe { - Arc::new(GenericByteViewArray::::new_unchecked( - views, - completed, - null_buffer, - )) - } - } - - fn take_n_inner(&mut self, n: usize) -> ArrayRef { - debug_assert!(self.len() >= n); - - // The `n == len` case, we need to take all - if self.len() == n { - let new_builder = Self::new().with_max_block_size(self.max_block_size); - let cur_builder = replace(self, new_builder); - return cur_builder.build_inner(); - } - - // The `n < len` case - // Take n for nulls - let null_buffer = self.nulls.take_n(n); - - // Take n for values: - // - Take first n `view`s from `views` - // - // - Find the last non-inlined `view`, if all inlined, - // we can build array and return happily, otherwise we - // we need to continue to process related buffers - // - // - Get the last related `buffer index`(let's name it `buffer index n`) - // from last non-inlined `view` - // - // - Take buffers, the key is that we need to know if we need to take - // the whole last related buffer. 
The logic is a bit complex, you can - // detail in `take_buffers_with_whole_last`, `take_buffers_with_partial_last` - // and other related steps in following - // - // - Shift the `buffer index` of remaining non-inlined `views` - // - let first_n_views = self.views.drain(0..n).collect::>(); - - let last_non_inlined_view = first_n_views - .iter() - .rev() - .find(|view| ((**view) as u32) > 12); - - // All taken views inlined - let Some(view) = last_non_inlined_view else { - let views = ScalarBuffer::from(first_n_views); - - // Safety: - // * all views were correctly made - // * (if utf8): Input was valid Utf8 so buffer contents are - // valid utf8 as well - unsafe { - return Arc::new(GenericByteViewArray::::new_unchecked( - views, - Vec::new(), - null_buffer, - )); - } - }; - - // Unfortunately, some taken views non-inlined - let view = ByteView::from(*view); - let last_remaining_buffer_index = view.buffer_index as usize; - - // Check should we take the whole `last_remaining_buffer_index` buffer - let take_whole_last_buffer = self.should_take_whole_buffer( - last_remaining_buffer_index, - (view.offset + view.length) as usize, - ); - - // Take related buffers - let buffers = if take_whole_last_buffer { - self.take_buffers_with_whole_last(last_remaining_buffer_index) - } else { - self.take_buffers_with_partial_last( - last_remaining_buffer_index, - (view.offset + view.length) as usize, - ) - }; - - // Shift `buffer index`s finally - let shifts = if take_whole_last_buffer { - last_remaining_buffer_index + 1 - } else { - last_remaining_buffer_index - }; - - self.views.iter_mut().for_each(|view| { - if (*view as u32) > 12 { - let mut byte_view = ByteView::from(*view); - byte_view.buffer_index -= shifts as u32; - *view = byte_view.as_u128(); - } - }); - - // Build array and return - let views = ScalarBuffer::from(first_n_views); - - // Safety: - // * all views were correctly made - // * (if utf8): Input was valid Utf8 so buffer contents are - // valid utf8 as well - unsafe { - Arc::new(GenericByteViewArray::::new_unchecked( - views, - buffers, - null_buffer, - )) - } - } - - fn take_buffers_with_whole_last( - &mut self, - last_remaining_buffer_index: usize, - ) -> Vec { - if last_remaining_buffer_index == self.completed.len() { - self.flush_in_progress(); - } - self.completed - .drain(0..last_remaining_buffer_index + 1) - .collect() - } - - fn take_buffers_with_partial_last( - &mut self, - last_remaining_buffer_index: usize, - last_take_len: usize, - ) -> Vec { - let mut take_buffers = Vec::with_capacity(last_remaining_buffer_index + 1); - - // Take `0 ~ last_remaining_buffer_index - 1` buffers - if !self.completed.is_empty() || last_remaining_buffer_index == 0 { - take_buffers.extend(self.completed.drain(0..last_remaining_buffer_index)); - } - - // Process the `last_remaining_buffer_index` buffers - let last_buffer = if last_remaining_buffer_index < self.completed.len() { - // If it is in `completed`, simply clone - self.completed[last_remaining_buffer_index].clone() - } else { - // If it is `in_progress`, copied `0 ~ offset` part - let taken_last_buffer = self.in_progress[0..last_take_len].to_vec(); - Buffer::from_vec(taken_last_buffer) - }; - take_buffers.push(last_buffer); - - take_buffers - } - - #[inline] - fn should_take_whole_buffer(&self, buffer_index: usize, take_len: usize) -> bool { - if buffer_index < self.completed.len() { - take_len == self.completed[buffer_index].len() - } else { - take_len == self.in_progress.len() - } - } - - fn flush_in_progress(&mut self) { - let flushed_block 
= replace( - &mut self.in_progress, - Vec::with_capacity(self.max_block_size), - ); - let buffer = Buffer::from_vec(flushed_block); - self.completed.push(buffer); - } -} - -impl GroupColumn for ByteViewGroupValueBuilder { - fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool { - self.equal_to_inner(lhs_row, array, rhs_row) - } - - fn append_val(&mut self, array: &ArrayRef, row: usize) { - self.append_val_inner(array, row) - } - - fn vectorized_equal_to( - &self, - group_indices: &[usize], - array: &ArrayRef, - rows: &[usize], - equal_to_results: &mut [bool], - ) { - self.vectorized_equal_to_inner(group_indices, array, rows, equal_to_results); - } - - fn vectorized_append(&mut self, array: &ArrayRef, rows: &[usize]) { - self.vectorized_append_inner(array, rows); - } - - fn len(&self) -> usize { - self.views.len() - } - - fn size(&self) -> usize { - let buffers_size = self - .completed - .iter() - .map(|buf| buf.capacity() * size_of::()) - .sum::(); - - self.nulls.allocated_size() - + self.views.capacity() * size_of::() - + self.in_progress.capacity() * size_of::() - + buffers_size - + size_of::() - } - - fn build(self: Box) -> ArrayRef { - Self::build_inner(*self) - } - - fn take_n(&mut self, n: usize) -> ArrayRef { - self.take_n_inner(n) - } -} - -/// Determines if the nullability of the existing and new input array can be used -/// to short-circuit the comparison of the two values. -/// -/// Returns `Some(result)` if the result of the comparison can be determined -/// from the nullness of the two values, and `None` if the comparison must be -/// done on the values themselves. -fn nulls_equal_to(lhs_null: bool, rhs_null: bool) -> Option { - match (lhs_null, rhs_null) { - (true, true) => Some(true), - (false, true) | (true, false) => Some(false), - _ => None, - } -} - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use arrow::{ - array::AsArray, - datatypes::{Int64Type, StringViewType}, - }; - use arrow_array::{Array, ArrayRef, Int64Array, StringArray, StringViewArray}; - use arrow_buffer::{BooleanBufferBuilder, NullBuffer}; - use datafusion_physical_expr::binary_map::OutputType; - - use crate::aggregates::group_values::group_column::{ - ByteViewGroupValueBuilder, PrimitiveGroupValueBuilder, - }; - - use super::{ByteGroupValueBuilder, GroupColumn}; - - // ======================================================================== - // Tests for primitive builders - // ======================================================================== - #[test] - fn test_nullable_primitive_equal_to() { - let append = |builder: &mut PrimitiveGroupValueBuilder, - builder_array: &ArrayRef, - append_rows: &[usize]| { - for &index in append_rows { - builder.append_val(builder_array, index); - } - }; - - let equal_to = |builder: &PrimitiveGroupValueBuilder, - lhs_rows: &[usize], - input_array: &ArrayRef, - rhs_rows: &[usize], - equal_to_results: &mut Vec| { - let iter = lhs_rows.iter().zip(rhs_rows.iter()); - for (idx, (&lhs_row, &rhs_row)) in iter.enumerate() { - equal_to_results[idx] = builder.equal_to(lhs_row, input_array, rhs_row); - } - }; - - test_nullable_primitive_equal_to_internal(append, equal_to); - } - - #[test] - fn test_nullable_primitive_vectorized_equal_to() { - let append = |builder: &mut PrimitiveGroupValueBuilder, - builder_array: &ArrayRef, - append_rows: &[usize]| { - builder.vectorized_append(builder_array, append_rows); - }; - - let equal_to = |builder: &PrimitiveGroupValueBuilder, - lhs_rows: &[usize], - input_array: &ArrayRef, - rhs_rows: &[usize], - 
equal_to_results: &mut Vec| { - builder.vectorized_equal_to( - lhs_rows, - input_array, - rhs_rows, - equal_to_results, - ); - }; - - test_nullable_primitive_equal_to_internal(append, equal_to); - } - - fn test_nullable_primitive_equal_to_internal(mut append: A, mut equal_to: E) - where - A: FnMut(&mut PrimitiveGroupValueBuilder, &ArrayRef, &[usize]), - E: FnMut( - &PrimitiveGroupValueBuilder, - &[usize], - &ArrayRef, - &[usize], - &mut Vec, - ), - { - // Will cover such cases: - // - exist null, input not null - // - exist null, input null; values not equal - // - exist null, input null; values equal - // - exist not null, input null - // - exist not null, input not null; values not equal - // - exist not null, input not null; values equal - - // Define PrimitiveGroupValueBuilder - let mut builder = PrimitiveGroupValueBuilder::::new(); - let builder_array = Arc::new(Int64Array::from(vec![ - None, - None, - None, - Some(1), - Some(2), - Some(3), - ])) as ArrayRef; - append(&mut builder, &builder_array, &[0, 1, 2, 3, 4, 5]); - - // Define input array - let (_nulls, values, _) = - Int64Array::from(vec![Some(1), Some(2), None, None, Some(1), Some(3)]) - .into_parts(); - - // explicitly build a boolean buffer where one of the null values also happens to match - let mut boolean_buffer_builder = BooleanBufferBuilder::new(6); - boolean_buffer_builder.append(true); - boolean_buffer_builder.append(false); // this sets Some(2) to null above - boolean_buffer_builder.append(false); - boolean_buffer_builder.append(false); - boolean_buffer_builder.append(true); - boolean_buffer_builder.append(true); - let nulls = NullBuffer::new(boolean_buffer_builder.finish()); - let input_array = Arc::new(Int64Array::new(values, Some(nulls))) as ArrayRef; - - // Check - let mut equal_to_results = vec![true; builder.len()]; - equal_to( - &builder, - &[0, 1, 2, 3, 4, 5], - &input_array, - &[0, 1, 2, 3, 4, 5], - &mut equal_to_results, - ); - - assert!(!equal_to_results[0]); - assert!(equal_to_results[1]); - assert!(equal_to_results[2]); - assert!(!equal_to_results[3]); - assert!(!equal_to_results[4]); - assert!(equal_to_results[5]); - } - - #[test] - fn test_not_nullable_primitive_equal_to() { - let append = |builder: &mut PrimitiveGroupValueBuilder, - builder_array: &ArrayRef, - append_rows: &[usize]| { - for &index in append_rows { - builder.append_val(builder_array, index); - } - }; - - let equal_to = |builder: &PrimitiveGroupValueBuilder, - lhs_rows: &[usize], - input_array: &ArrayRef, - rhs_rows: &[usize], - equal_to_results: &mut Vec| { - let iter = lhs_rows.iter().zip(rhs_rows.iter()); - for (idx, (&lhs_row, &rhs_row)) in iter.enumerate() { - equal_to_results[idx] = builder.equal_to(lhs_row, input_array, rhs_row); - } - }; - - test_not_nullable_primitive_equal_to_internal(append, equal_to); - } - - #[test] - fn test_not_nullable_primitive_vectorized_equal_to() { - let append = |builder: &mut PrimitiveGroupValueBuilder, - builder_array: &ArrayRef, - append_rows: &[usize]| { - builder.vectorized_append(builder_array, append_rows); - }; - - let equal_to = |builder: &PrimitiveGroupValueBuilder, - lhs_rows: &[usize], - input_array: &ArrayRef, - rhs_rows: &[usize], - equal_to_results: &mut Vec| { - builder.vectorized_equal_to( - lhs_rows, - input_array, - rhs_rows, - equal_to_results, - ); - }; - - test_not_nullable_primitive_equal_to_internal(append, equal_to); - } - - fn test_not_nullable_primitive_equal_to_internal(mut append: A, mut equal_to: E) - where - A: FnMut(&mut PrimitiveGroupValueBuilder, &ArrayRef, 
&[usize]), - E: FnMut( - &PrimitiveGroupValueBuilder, - &[usize], - &ArrayRef, - &[usize], - &mut Vec, - ), - { - // Will cover such cases: - // - values equal - // - values not equal - - // Define PrimitiveGroupValueBuilder - let mut builder = PrimitiveGroupValueBuilder::::new(); - let builder_array = - Arc::new(Int64Array::from(vec![Some(0), Some(1)])) as ArrayRef; - append(&mut builder, &builder_array, &[0, 1]); - - // Define input array - let input_array = Arc::new(Int64Array::from(vec![Some(0), Some(2)])) as ArrayRef; - - // Check - let mut equal_to_results = vec![true; builder.len()]; - equal_to( - &builder, - &[0, 1], - &input_array, - &[0, 1], - &mut equal_to_results, - ); - - assert!(equal_to_results[0]); - assert!(!equal_to_results[1]); - } - - #[test] - fn test_nullable_primitive_vectorized_operation_special_case() { - // Test the special `all nulls` or `not nulls` input array case - // for vectorized append and equal to - - let mut builder = PrimitiveGroupValueBuilder::::new(); - - // All nulls input array - let all_nulls_input_array = Arc::new(Int64Array::from(vec![ - Option::::None, - None, - None, - None, - None, - ])) as _; - builder.vectorized_append(&all_nulls_input_array, &[0, 1, 2, 3, 4]); - - let mut equal_to_results = vec![true; all_nulls_input_array.len()]; - builder.vectorized_equal_to( - &[0, 1, 2, 3, 4], - &all_nulls_input_array, - &[0, 1, 2, 3, 4], - &mut equal_to_results, - ); - - assert!(equal_to_results[0]); - assert!(equal_to_results[1]); - assert!(equal_to_results[2]); - assert!(equal_to_results[3]); - assert!(equal_to_results[4]); - - // All not nulls input array - let all_not_nulls_input_array = Arc::new(Int64Array::from(vec![ - Some(1), - Some(2), - Some(3), - Some(4), - Some(5), - ])) as _; - builder.vectorized_append(&all_not_nulls_input_array, &[0, 1, 2, 3, 4]); - - let mut equal_to_results = vec![true; all_not_nulls_input_array.len()]; - builder.vectorized_equal_to( - &[5, 6, 7, 8, 9], - &all_not_nulls_input_array, - &[0, 1, 2, 3, 4], - &mut equal_to_results, - ); - - assert!(equal_to_results[0]); - assert!(equal_to_results[1]); - assert!(equal_to_results[2]); - assert!(equal_to_results[3]); - assert!(equal_to_results[4]); - } - - // ======================================================================== - // Tests for byte builders - // ======================================================================== - #[test] - fn test_byte_take_n() { - let mut builder = ByteGroupValueBuilder::::new(OutputType::Utf8); - let array = Arc::new(StringArray::from(vec![Some("a"), None])) as ArrayRef; - // a, null, null - builder.append_val(&array, 0); - builder.append_val(&array, 1); - builder.append_val(&array, 1); - - // (a, null) remaining: null - let output = builder.take_n(2); - assert_eq!(&output, &array); - - // null, a, null, a - builder.append_val(&array, 0); - builder.append_val(&array, 1); - builder.append_val(&array, 0); - - // (null, a) remaining: (null, a) - let output = builder.take_n(2); - let array = Arc::new(StringArray::from(vec![None, Some("a")])) as ArrayRef; - assert_eq!(&output, &array); - - let array = Arc::new(StringArray::from(vec![ - Some("a"), - None, - Some("longstringfortest"), - ])) as ArrayRef; - - // null, a, longstringfortest, null, null - builder.append_val(&array, 2); - builder.append_val(&array, 1); - builder.append_val(&array, 1); - - // (null, a, longstringfortest, null) remaining: (null) - let output = builder.take_n(4); - let array = Arc::new(StringArray::from(vec![ - None, - Some("a"), - Some("longstringfortest"), - None, 
- ])) as ArrayRef; - assert_eq!(&output, &array); - } - - #[test] - fn test_byte_equal_to() { - let append = |builder: &mut ByteGroupValueBuilder, - builder_array: &ArrayRef, - append_rows: &[usize]| { - for &index in append_rows { - builder.append_val(builder_array, index); - } - }; - - let equal_to = |builder: &ByteGroupValueBuilder, - lhs_rows: &[usize], - input_array: &ArrayRef, - rhs_rows: &[usize], - equal_to_results: &mut Vec| { - let iter = lhs_rows.iter().zip(rhs_rows.iter()); - for (idx, (&lhs_row, &rhs_row)) in iter.enumerate() { - equal_to_results[idx] = builder.equal_to(lhs_row, input_array, rhs_row); - } - }; - - test_byte_equal_to_internal(append, equal_to); - } - - #[test] - fn test_byte_vectorized_equal_to() { - let append = |builder: &mut ByteGroupValueBuilder, - builder_array: &ArrayRef, - append_rows: &[usize]| { - builder.vectorized_append(builder_array, append_rows); - }; - - let equal_to = |builder: &ByteGroupValueBuilder, - lhs_rows: &[usize], - input_array: &ArrayRef, - rhs_rows: &[usize], - equal_to_results: &mut Vec| { - builder.vectorized_equal_to( - lhs_rows, - input_array, - rhs_rows, - equal_to_results, - ); - }; - - test_byte_equal_to_internal(append, equal_to); - } - - #[test] - fn test_byte_vectorized_operation_special_case() { - // Test the special `all nulls` or `not nulls` input array case - // for vectorized append and equal to - - let mut builder = ByteGroupValueBuilder::::new(OutputType::Utf8); - - // All nulls input array - let all_nulls_input_array = Arc::new(StringArray::from(vec![ - Option::<&str>::None, - None, - None, - None, - None, - ])) as _; - builder.vectorized_append(&all_nulls_input_array, &[0, 1, 2, 3, 4]); - - let mut equal_to_results = vec![true; all_nulls_input_array.len()]; - builder.vectorized_equal_to( - &[0, 1, 2, 3, 4], - &all_nulls_input_array, - &[0, 1, 2, 3, 4], - &mut equal_to_results, - ); - - assert!(equal_to_results[0]); - assert!(equal_to_results[1]); - assert!(equal_to_results[2]); - assert!(equal_to_results[3]); - assert!(equal_to_results[4]); - - // All not nulls input array - let all_not_nulls_input_array = Arc::new(StringArray::from(vec![ - Some("string1"), - Some("string2"), - Some("string3"), - Some("string4"), - Some("string5"), - ])) as _; - builder.vectorized_append(&all_not_nulls_input_array, &[0, 1, 2, 3, 4]); - - let mut equal_to_results = vec![true; all_not_nulls_input_array.len()]; - builder.vectorized_equal_to( - &[5, 6, 7, 8, 9], - &all_not_nulls_input_array, - &[0, 1, 2, 3, 4], - &mut equal_to_results, - ); - - assert!(equal_to_results[0]); - assert!(equal_to_results[1]); - assert!(equal_to_results[2]); - assert!(equal_to_results[3]); - assert!(equal_to_results[4]); - } - - fn test_byte_equal_to_internal(mut append: A, mut equal_to: E) - where - A: FnMut(&mut ByteGroupValueBuilder, &ArrayRef, &[usize]), - E: FnMut( - &ByteGroupValueBuilder, - &[usize], - &ArrayRef, - &[usize], - &mut Vec, - ), - { - // Will cover such cases: - // - exist null, input not null - // - exist null, input null; values not equal - // - exist null, input null; values equal - // - exist not null, input null - // - exist not null, input not null; values not equal - // - exist not null, input not null; values equal - - // Define PrimitiveGroupValueBuilder - let mut builder = ByteGroupValueBuilder::::new(OutputType::Utf8); - let builder_array = Arc::new(StringArray::from(vec![ - None, - None, - None, - Some("foo"), - Some("bar"), - Some("baz"), - ])) as ArrayRef; - append(&mut builder, &builder_array, &[0, 1, 2, 3, 4, 5]); - - // 
Define input array - let (offsets, buffer, _nulls) = StringArray::from(vec![ - Some("foo"), - Some("bar"), - None, - None, - Some("foo"), - Some("baz"), - ]) - .into_parts(); - - // explicitly build a boolean buffer where one of the null values also happens to match - let mut boolean_buffer_builder = BooleanBufferBuilder::new(6); - boolean_buffer_builder.append(true); - boolean_buffer_builder.append(false); // this sets Some("bar") to null above - boolean_buffer_builder.append(false); - boolean_buffer_builder.append(false); - boolean_buffer_builder.append(true); - boolean_buffer_builder.append(true); - let nulls = NullBuffer::new(boolean_buffer_builder.finish()); - let input_array = - Arc::new(StringArray::new(offsets, buffer, Some(nulls))) as ArrayRef; - - // Check - let mut equal_to_results = vec![true; builder.len()]; - equal_to( - &builder, - &[0, 1, 2, 3, 4, 5], - &input_array, - &[0, 1, 2, 3, 4, 5], - &mut equal_to_results, - ); - - assert!(!equal_to_results[0]); - assert!(equal_to_results[1]); - assert!(equal_to_results[2]); - assert!(!equal_to_results[3]); - assert!(!equal_to_results[4]); - assert!(equal_to_results[5]); - } - - // ======================================================================== - // Tests for byte view builders - // ======================================================================== - #[test] - fn test_byte_view_append_val() { - let mut builder = - ByteViewGroupValueBuilder::::new().with_max_block_size(60); - let builder_array = StringViewArray::from(vec![ - Some("this string is quite long"), // in buffer 0 - Some("foo"), - None, - Some("bar"), - Some("this string is also quite long"), // buffer 0 - Some("this string is quite long"), // buffer 1 - Some("bar"), - ]); - let builder_array: ArrayRef = Arc::new(builder_array); - for row in 0..builder_array.len() { - builder.append_val(&builder_array, row); - } - - let output = Box::new(builder).build(); - // should be 2 output buffers to hold all the data - assert_eq!(output.as_string_view().data_buffers().len(), 2); - assert_eq!(&output, &builder_array) - } - - #[test] - fn test_byte_view_equal_to() { - let append = |builder: &mut ByteViewGroupValueBuilder, - builder_array: &ArrayRef, - append_rows: &[usize]| { - for &index in append_rows { - builder.append_val(builder_array, index); - } - }; - - let equal_to = |builder: &ByteViewGroupValueBuilder, - lhs_rows: &[usize], - input_array: &ArrayRef, - rhs_rows: &[usize], - equal_to_results: &mut Vec| { - let iter = lhs_rows.iter().zip(rhs_rows.iter()); - for (idx, (&lhs_row, &rhs_row)) in iter.enumerate() { - equal_to_results[idx] = builder.equal_to(lhs_row, input_array, rhs_row); - } - }; - - test_byte_view_equal_to_internal(append, equal_to); - } - - #[test] - fn test_byte_view_vectorized_equal_to() { - let append = |builder: &mut ByteViewGroupValueBuilder, - builder_array: &ArrayRef, - append_rows: &[usize]| { - builder.vectorized_append(builder_array, append_rows); - }; - - let equal_to = |builder: &ByteViewGroupValueBuilder, - lhs_rows: &[usize], - input_array: &ArrayRef, - rhs_rows: &[usize], - equal_to_results: &mut Vec| { - builder.vectorized_equal_to( - lhs_rows, - input_array, - rhs_rows, - equal_to_results, - ); - }; - - test_byte_view_equal_to_internal(append, equal_to); - } - - #[test] - fn test_byte_view_vectorized_operation_special_case() { - // Test the special `all nulls` or `not nulls` input array case - // for vectorized append and equal to - - let mut builder = - ByteViewGroupValueBuilder::::new().with_max_block_size(60); - - // All nulls 
input array - let all_nulls_input_array = Arc::new(StringViewArray::from(vec![ - Option::<&str>::None, - None, - None, - None, - None, - ])) as _; - builder.vectorized_append(&all_nulls_input_array, &[0, 1, 2, 3, 4]); - - let mut equal_to_results = vec![true; all_nulls_input_array.len()]; - builder.vectorized_equal_to( - &[0, 1, 2, 3, 4], - &all_nulls_input_array, - &[0, 1, 2, 3, 4], - &mut equal_to_results, - ); - - assert!(equal_to_results[0]); - assert!(equal_to_results[1]); - assert!(equal_to_results[2]); - assert!(equal_to_results[3]); - assert!(equal_to_results[4]); - - // All not nulls input array - let all_not_nulls_input_array = Arc::new(StringViewArray::from(vec![ - Some("stringview1"), - Some("stringview2"), - Some("stringview3"), - Some("stringview4"), - Some("stringview5"), - ])) as _; - builder.vectorized_append(&all_not_nulls_input_array, &[0, 1, 2, 3, 4]); - - let mut equal_to_results = vec![true; all_not_nulls_input_array.len()]; - builder.vectorized_equal_to( - &[5, 6, 7, 8, 9], - &all_not_nulls_input_array, - &[0, 1, 2, 3, 4], - &mut equal_to_results, - ); - - assert!(equal_to_results[0]); - assert!(equal_to_results[1]); - assert!(equal_to_results[2]); - assert!(equal_to_results[3]); - assert!(equal_to_results[4]); - } - - fn test_byte_view_equal_to_internal(mut append: A, mut equal_to: E) - where - A: FnMut(&mut ByteViewGroupValueBuilder, &ArrayRef, &[usize]), - E: FnMut( - &ByteViewGroupValueBuilder, - &[usize], - &ArrayRef, - &[usize], - &mut Vec, - ), - { - // Will cover such cases: - // - exist null, input not null - // - exist null, input null; values not equal - // - exist null, input null; values equal - // - exist not null, input null - // - exist not null, input not null; value lens not equal - // - exist not null, input not null; value not equal(inlined case) - // - exist not null, input not null; value equal(inlined case) - // - // - exist not null, input not null; value not equal - // (non-inlined case + prefix not equal) - // - // - exist not null, input not null; value not equal - // (non-inlined case + value in `completed`) - // - // - exist not null, input not null; value equal - // (non-inlined case + value in `completed`) - // - // - exist not null, input not null; value not equal - // (non-inlined case + value in `in_progress`) - // - // - exist not null, input not null; value equal - // (non-inlined case + value in `in_progress`) - - // Set the block size to 40 for ensuring some unlined values are in `in_progress`, - // and some are in `completed`, so both two branches in `value` function can be covered. 
- let mut builder = - ByteViewGroupValueBuilder::::new().with_max_block_size(60); - let builder_array = Arc::new(StringViewArray::from(vec![ - None, - None, - None, - Some("foo"), - Some("bazz"), - Some("foo"), - Some("bar"), - Some("I am a long string for test eq in completed"), - Some("I am a long string for test eq in progress"), - ])) as ArrayRef; - append(&mut builder, &builder_array, &[0, 1, 2, 3, 4, 5, 6, 7, 8]); - - // Define input array - let (views, buffer, _nulls) = StringViewArray::from(vec![ - Some("foo"), - Some("bar"), // set to null - None, - None, - Some("baz"), - Some("oof"), - Some("bar"), - Some("i am a long string for test eq in completed"), - Some("I am a long string for test eq in COMPLETED"), - Some("I am a long string for test eq in completed"), - Some("I am a long string for test eq in PROGRESS"), - Some("I am a long string for test eq in progress"), - ]) - .into_parts(); - - // explicitly build a boolean buffer where one of the null values also happens to match - let mut boolean_buffer_builder = BooleanBufferBuilder::new(9); - boolean_buffer_builder.append(true); - boolean_buffer_builder.append(false); // this sets Some("bar") to null above - boolean_buffer_builder.append(false); - boolean_buffer_builder.append(false); - boolean_buffer_builder.append(true); - boolean_buffer_builder.append(true); - boolean_buffer_builder.append(true); - boolean_buffer_builder.append(true); - boolean_buffer_builder.append(true); - boolean_buffer_builder.append(true); - boolean_buffer_builder.append(true); - boolean_buffer_builder.append(true); - let nulls = NullBuffer::new(boolean_buffer_builder.finish()); - let input_array = - Arc::new(StringViewArray::new(views, buffer, Some(nulls))) as ArrayRef; - - // Check - let mut equal_to_results = vec![true; input_array.len()]; - equal_to( - &builder, - &[0, 1, 2, 3, 4, 5, 6, 7, 7, 7, 8, 8], - &input_array, - &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], - &mut equal_to_results, - ); - - assert!(!equal_to_results[0]); - assert!(equal_to_results[1]); - assert!(equal_to_results[2]); - assert!(!equal_to_results[3]); - assert!(!equal_to_results[4]); - assert!(!equal_to_results[5]); - assert!(equal_to_results[6]); - assert!(!equal_to_results[7]); - assert!(!equal_to_results[8]); - assert!(equal_to_results[9]); - assert!(!equal_to_results[10]); - assert!(equal_to_results[11]); - } - - #[test] - fn test_byte_view_take_n() { - // ####### Define cases and init ####### - - // `take_n` is really complex, we should consider and test following situations: - // 1. Take nulls - // 2. Take all `inlined`s - // 3. Take non-inlined + partial last buffer in `completed` - // 4. Take non-inlined + whole last buffer in `completed` - // 5. Take non-inlined + partial last `in_progress` - // 6. Take non-inlined + whole last buffer in `in_progress` - // 7. 
Take all views at once - - let mut builder = - ByteViewGroupValueBuilder::::new().with_max_block_size(60); - let input_array = StringViewArray::from(vec![ - // Test situation 1 - None, - None, - // Test situation 2 (also test take null together) - None, - Some("foo"), - Some("bar"), - // Test situation 3 (also test take null + inlined) - None, - Some("foo"), - Some("this string is quite long"), - Some("this string is also quite long"), - // Test situation 4 (also test take null + inlined) - None, - Some("bar"), - Some("this string is quite long"), - // Test situation 5 (also test take null + inlined) - None, - Some("foo"), - Some("another string that is is quite long"), - Some("this string not so long"), - // Test situation 6 (also test take null + inlined + insert again after taking) - None, - Some("bar"), - Some("this string is quite long"), - // Insert 4 and just take 3 to ensure it will go the path of situation 6 - None, - // Finally, we create a new builder, insert the whole array and then - // take whole at once for testing situation 7 - ]); - - let input_array: ArrayRef = Arc::new(input_array); - let first_ones_to_append = 16; // For testing situation 1~5 - let second_ones_to_append = 4; // For testing situation 6 - let final_ones_to_append = input_array.len(); // For testing situation 7 - - // ####### Test situation 1~5 ####### - for row in 0..first_ones_to_append { - builder.append_val(&input_array, row); - } - - assert_eq!(builder.completed.len(), 2); - assert_eq!(builder.in_progress.len(), 59); - - // Situation 1 - let taken_array = builder.take_n(2); - assert_eq!(&taken_array, &input_array.slice(0, 2)); - - // Situation 2 - let taken_array = builder.take_n(3); - assert_eq!(&taken_array, &input_array.slice(2, 3)); - - // Situation 3 - let taken_array = builder.take_n(3); - assert_eq!(&taken_array, &input_array.slice(5, 3)); - - let taken_array = builder.take_n(1); - assert_eq!(&taken_array, &input_array.slice(8, 1)); - - // Situation 4 - let taken_array = builder.take_n(3); - assert_eq!(&taken_array, &input_array.slice(9, 3)); - - // Situation 5 - let taken_array = builder.take_n(3); - assert_eq!(&taken_array, &input_array.slice(12, 3)); - - let taken_array = builder.take_n(1); - assert_eq!(&taken_array, &input_array.slice(15, 1)); - - // ####### Test situation 6 ####### - assert!(builder.completed.is_empty()); - assert!(builder.in_progress.is_empty()); - assert!(builder.views.is_empty()); - - for row in first_ones_to_append..first_ones_to_append + second_ones_to_append { - builder.append_val(&input_array, row); - } - - assert!(builder.completed.is_empty()); - assert_eq!(builder.in_progress.len(), 25); - - let taken_array = builder.take_n(3); - assert_eq!(&taken_array, &input_array.slice(16, 3)); - - // ####### Test situation 7 ####### - // Create a new builder - let mut builder = - ByteViewGroupValueBuilder::::new().with_max_block_size(60); - - for row in 0..final_ones_to_append { - builder.append_val(&input_array, row); - } - - assert_eq!(builder.completed.len(), 3); - assert_eq!(builder.in_progress.len(), 25); - - let taken_array = builder.take_n(final_ones_to_append); - assert_eq!(&taken_array, &input_array); - } -} diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs index aefd9c162246..12ed25a0ea34 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -27,9 +27,9 @@ pub(crate) mod primitive; use 
datafusion_expr::EmitTo; use primitive::GroupValuesPrimitive; -mod column; +mod multi_column; mod row; -use column::GroupValuesColumn; +use multi_column::GroupValuesColumn; use row::GroupValuesRows; mod bytes; @@ -39,7 +39,6 @@ use datafusion_physical_expr::binary_map::OutputType; use crate::aggregates::order::GroupOrdering; -mod group_column; mod null_builder; /// Stores the group values during hash aggregation. @@ -148,7 +147,7 @@ pub fn new_group_values( } } - if column::supported_schema(schema.as_ref()) { + if multi_column::supported_schema(schema.as_ref()) { if matches!(group_ordering, GroupOrdering::None) { Ok(Box::new(GroupValuesColumn::::try_new(schema)?)) } else { diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_column/bytes.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_column/bytes.rs new file mode 100644 index 000000000000..820d28fc58e7 --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_column/bytes.rs @@ -0,0 +1,633 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::aggregates::group_values::multi_column::{nulls_equal_to, GroupColumn}; +use crate::aggregates::group_values::null_builder::MaybeNullBufferBuilder; +use arrow::array::{AsArray, BufferBuilder, GenericBinaryArray, GenericStringArray}; +use arrow::buffer::{OffsetBuffer, ScalarBuffer}; +use arrow::datatypes::{ByteArrayType, DataType, GenericBinaryType}; +use arrow_array::types::GenericStringType; +use arrow_array::{Array, ArrayRef, GenericByteArray, OffsetSizeTrait}; +use datafusion_common::utils::proxy::VecAllocExt; +use datafusion_physical_expr_common::binary_map::{OutputType, INITIAL_BUFFER_CAPACITY}; +use itertools::izip; +use std::mem::size_of; +use std::sync::Arc; +use std::vec; + +/// An implementation of [`GroupColumn`] for binary and utf8 types. +/// +/// Stores a collection of binary or utf8 group values in a single buffer +/// in a way that allows: +/// +/// 1. Efficient comparison of incoming rows to existing rows +/// 2. Efficient construction of the final output array +pub struct ByteGroupValueBuilder +where + O: OffsetSizeTrait, +{ + output_type: OutputType, + buffer: BufferBuilder, + /// Offsets into `buffer` for each distinct value. These offsets as used + /// directly to create the final `GenericBinaryArray`. The `i`th string is + /// stored in the range `offsets[i]..offsets[i+1]` in `buffer`. Null values + /// are stored as a zero length string. 
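+    ///
+    /// For example (illustration of the invariant above, not part of the API):
+    /// storing the group values `"ab"`, `""` and `"cde"` in that order yields
+    /// `buffer = "abcde"` and `offsets = [0, 2, 2, 5]`.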
+ offsets: Vec, + /// Nulls + nulls: MaybeNullBufferBuilder, +} + +impl ByteGroupValueBuilder +where + O: OffsetSizeTrait, +{ + pub fn new(output_type: OutputType) -> Self { + Self { + output_type, + buffer: BufferBuilder::new(INITIAL_BUFFER_CAPACITY), + offsets: vec![O::default()], + nulls: MaybeNullBufferBuilder::new(), + } + } + + fn equal_to_inner(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool + where + B: ByteArrayType, + { + let array = array.as_bytes::(); + self.do_equal_to_inner(lhs_row, array, rhs_row) + } + + fn append_val_inner(&mut self, array: &ArrayRef, row: usize) + where + B: ByteArrayType, + { + let arr = array.as_bytes::(); + if arr.is_null(row) { + self.nulls.append(true); + // nulls need a zero length in the offset buffer + let offset = self.buffer.len(); + self.offsets.push(O::usize_as(offset)); + } else { + self.nulls.append(false); + self.do_append_val_inner(arr, row); + } + } + + fn vectorized_equal_to_inner( + &self, + lhs_rows: &[usize], + array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut [bool], + ) where + B: ByteArrayType, + { + let array = array.as_bytes::(); + + let iter = izip!( + lhs_rows.iter(), + rhs_rows.iter(), + equal_to_results.iter_mut(), + ); + + for (&lhs_row, &rhs_row, equal_to_result) in iter { + // Has found not equal to, don't need to check + if !*equal_to_result { + continue; + } + + *equal_to_result = self.do_equal_to_inner(lhs_row, array, rhs_row); + } + } + + fn vectorized_append_inner(&mut self, array: &ArrayRef, rows: &[usize]) + where + B: ByteArrayType, + { + let arr = array.as_bytes::(); + let null_count = array.null_count(); + let num_rows = array.len(); + let all_null_or_non_null = if null_count == 0 { + Some(true) + } else if null_count == num_rows { + Some(false) + } else { + None + }; + + match all_null_or_non_null { + None => { + for &row in rows { + if arr.is_null(row) { + self.nulls.append(true); + // nulls need a zero length in the offset buffer + let offset = self.buffer.len(); + self.offsets.push(O::usize_as(offset)); + } else { + self.nulls.append(false); + self.do_append_val_inner(arr, row); + } + } + } + + Some(true) => { + self.nulls.append_n(rows.len(), false); + for &row in rows { + self.do_append_val_inner(arr, row); + } + } + + Some(false) => { + self.nulls.append_n(rows.len(), true); + + let new_len = self.offsets.len() + rows.len(); + let offset = self.buffer.len(); + self.offsets.resize(new_len, O::usize_as(offset)); + } + } + } + + fn do_equal_to_inner( + &self, + lhs_row: usize, + array: &GenericByteArray, + rhs_row: usize, + ) -> bool + where + B: ByteArrayType, + { + let exist_null = self.nulls.is_null(lhs_row); + let input_null = array.is_null(rhs_row); + if let Some(result) = nulls_equal_to(exist_null, input_null) { + return result; + } + // Otherwise, we need to check their values + self.value(lhs_row) == (array.value(rhs_row).as_ref() as &[u8]) + } + + fn do_append_val_inner(&mut self, array: &GenericByteArray, row: usize) + where + B: ByteArrayType, + { + let value: &[u8] = array.value(row).as_ref(); + self.buffer.append_slice(value); + self.offsets.push(O::usize_as(self.buffer.len())); + } + + /// return the current value of the specified row irrespective of null + pub fn value(&self, row: usize) -> &[u8] { + let l = self.offsets[row].as_usize(); + let r = self.offsets[row + 1].as_usize(); + // Safety: the offsets are constructed correctly and never decrease + unsafe { self.buffer.as_slice().get_unchecked(l..r) } + } +} + +impl GroupColumn for ByteGroupValueBuilder +where 
+ O: OffsetSizeTrait, +{ + fn equal_to(&self, lhs_row: usize, column: &ArrayRef, rhs_row: usize) -> bool { + // Sanity array type + match self.output_type { + OutputType::Binary => { + debug_assert!(matches!( + column.data_type(), + DataType::Binary | DataType::LargeBinary + )); + self.equal_to_inner::>(lhs_row, column, rhs_row) + } + OutputType::Utf8 => { + debug_assert!(matches!( + column.data_type(), + DataType::Utf8 | DataType::LargeUtf8 + )); + self.equal_to_inner::>(lhs_row, column, rhs_row) + } + _ => unreachable!("View types should use `ArrowBytesViewMap`"), + } + } + + fn append_val(&mut self, column: &ArrayRef, row: usize) { + // Sanity array type + match self.output_type { + OutputType::Binary => { + debug_assert!(matches!( + column.data_type(), + DataType::Binary | DataType::LargeBinary + )); + self.append_val_inner::>(column, row) + } + OutputType::Utf8 => { + debug_assert!(matches!( + column.data_type(), + DataType::Utf8 | DataType::LargeUtf8 + )); + self.append_val_inner::>(column, row) + } + _ => unreachable!("View types should use `ArrowBytesViewMap`"), + }; + } + + fn vectorized_equal_to( + &self, + lhs_rows: &[usize], + array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut [bool], + ) { + // Sanity array type + match self.output_type { + OutputType::Binary => { + debug_assert!(matches!( + array.data_type(), + DataType::Binary | DataType::LargeBinary + )); + self.vectorized_equal_to_inner::>( + lhs_rows, + array, + rhs_rows, + equal_to_results, + ); + } + OutputType::Utf8 => { + debug_assert!(matches!( + array.data_type(), + DataType::Utf8 | DataType::LargeUtf8 + )); + self.vectorized_equal_to_inner::>( + lhs_rows, + array, + rhs_rows, + equal_to_results, + ); + } + _ => unreachable!("View types should use `ArrowBytesViewMap`"), + } + } + + fn vectorized_append(&mut self, column: &ArrayRef, rows: &[usize]) { + match self.output_type { + OutputType::Binary => { + debug_assert!(matches!( + column.data_type(), + DataType::Binary | DataType::LargeBinary + )); + self.vectorized_append_inner::>(column, rows) + } + OutputType::Utf8 => { + debug_assert!(matches!( + column.data_type(), + DataType::Utf8 | DataType::LargeUtf8 + )); + self.vectorized_append_inner::>(column, rows) + } + _ => unreachable!("View types should use `ArrowBytesViewMap`"), + }; + } + + fn len(&self) -> usize { + self.offsets.len() - 1 + } + + fn size(&self) -> usize { + self.buffer.capacity() * size_of::() + + self.offsets.allocated_size() + + self.nulls.allocated_size() + } + + fn build(self: Box) -> ArrayRef { + let Self { + output_type, + mut buffer, + offsets, + nulls, + } = *self; + + let null_buffer = nulls.build(); + + // SAFETY: the offsets were constructed correctly in `insert_if_new` -- + // monotonically increasing, overflows were checked. + let offsets = unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(offsets)) }; + let values = buffer.finish(); + match output_type { + OutputType::Binary => { + // SAFETY: the offsets were constructed correctly + Arc::new(unsafe { + GenericBinaryArray::new_unchecked(offsets, values, null_buffer) + }) + } + OutputType::Utf8 => { + // SAFETY: + // 1. the offsets were constructed safely + // + // 2. the input arrays were all the correct type and thus since + // all the values that went in were valid (e.g. 
utf8) so are all + // the values that come out + Arc::new(unsafe { + GenericStringArray::new_unchecked(offsets, values, null_buffer) + }) + } + _ => unreachable!("View types should use `ArrowBytesViewMap`"), + } + } + + fn take_n(&mut self, n: usize) -> ArrayRef { + debug_assert!(self.len() >= n); + let null_buffer = self.nulls.take_n(n); + let first_remaining_offset = O::as_usize(self.offsets[n]); + + // Given offests like [0, 2, 4, 5] and n = 1, we expect to get + // offsets [0, 2, 3]. We first create two offsets for first_n as [0, 2] and the remaining as [2, 4, 5]. + // And we shift the offset starting from 0 for the remaining one, [2, 4, 5] -> [0, 2, 3]. + let mut first_n_offsets = self.offsets.drain(0..n).collect::>(); + let offset_n = *self.offsets.first().unwrap(); + self.offsets + .iter_mut() + .for_each(|offset| *offset = offset.sub(offset_n)); + first_n_offsets.push(offset_n); + + // SAFETY: the offsets were constructed correctly in `insert_if_new` -- + // monotonically increasing, overflows were checked. + let offsets = + unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(first_n_offsets)) }; + + let mut remaining_buffer = + BufferBuilder::new(self.buffer.len() - first_remaining_offset); + // TODO: Current approach copy the remaining and truncate the original one + // Find out a way to avoid copying buffer but split the original one into two. + remaining_buffer.append_slice(&self.buffer.as_slice()[first_remaining_offset..]); + self.buffer.truncate(first_remaining_offset); + let values = self.buffer.finish(); + self.buffer = remaining_buffer; + + match self.output_type { + OutputType::Binary => { + // SAFETY: the offsets were constructed correctly + Arc::new(unsafe { + GenericBinaryArray::new_unchecked(offsets, values, null_buffer) + }) + } + OutputType::Utf8 => { + // SAFETY: + // 1. the offsets were constructed safely + // + // 2. we asserted the input arrays were all the correct type and + // thus since all the values that went in were valid (e.g. 
utf8) + // so are all the values that come out + Arc::new(unsafe { + GenericStringArray::new_unchecked(offsets, values, null_buffer) + }) + } + _ => unreachable!("View types should use `ArrowBytesViewMap`"), + } + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use crate::aggregates::group_values::multi_column::bytes::ByteGroupValueBuilder; + use arrow_array::{ArrayRef, StringArray}; + use arrow_buffer::{BooleanBufferBuilder, NullBuffer}; + use datafusion_physical_expr::binary_map::OutputType; + + use super::GroupColumn; + + #[test] + fn test_byte_take_n() { + let mut builder = ByteGroupValueBuilder::::new(OutputType::Utf8); + let array = Arc::new(StringArray::from(vec![Some("a"), None])) as ArrayRef; + // a, null, null + builder.append_val(&array, 0); + builder.append_val(&array, 1); + builder.append_val(&array, 1); + + // (a, null) remaining: null + let output = builder.take_n(2); + assert_eq!(&output, &array); + + // null, a, null, a + builder.append_val(&array, 0); + builder.append_val(&array, 1); + builder.append_val(&array, 0); + + // (null, a) remaining: (null, a) + let output = builder.take_n(2); + let array = Arc::new(StringArray::from(vec![None, Some("a")])) as ArrayRef; + assert_eq!(&output, &array); + + let array = Arc::new(StringArray::from(vec![ + Some("a"), + None, + Some("longstringfortest"), + ])) as ArrayRef; + + // null, a, longstringfortest, null, null + builder.append_val(&array, 2); + builder.append_val(&array, 1); + builder.append_val(&array, 1); + + // (null, a, longstringfortest, null) remaining: (null) + let output = builder.take_n(4); + let array = Arc::new(StringArray::from(vec![ + None, + Some("a"), + Some("longstringfortest"), + None, + ])) as ArrayRef; + assert_eq!(&output, &array); + } + + #[test] + fn test_byte_equal_to() { + let append = |builder: &mut ByteGroupValueBuilder, + builder_array: &ArrayRef, + append_rows: &[usize]| { + for &index in append_rows { + builder.append_val(builder_array, index); + } + }; + + let equal_to = |builder: &ByteGroupValueBuilder, + lhs_rows: &[usize], + input_array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut Vec| { + let iter = lhs_rows.iter().zip(rhs_rows.iter()); + for (idx, (&lhs_row, &rhs_row)) in iter.enumerate() { + equal_to_results[idx] = builder.equal_to(lhs_row, input_array, rhs_row); + } + }; + + test_byte_equal_to_internal(append, equal_to); + } + + #[test] + fn test_byte_vectorized_equal_to() { + let append = |builder: &mut ByteGroupValueBuilder, + builder_array: &ArrayRef, + append_rows: &[usize]| { + builder.vectorized_append(builder_array, append_rows); + }; + + let equal_to = |builder: &ByteGroupValueBuilder, + lhs_rows: &[usize], + input_array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut Vec| { + builder.vectorized_equal_to( + lhs_rows, + input_array, + rhs_rows, + equal_to_results, + ); + }; + + test_byte_equal_to_internal(append, equal_to); + } + + #[test] + fn test_byte_vectorized_operation_special_case() { + // Test the special `all nulls` or `not nulls` input array case + // for vectorized append and equal to + + let mut builder = ByteGroupValueBuilder::::new(OutputType::Utf8); + + // All nulls input array + let all_nulls_input_array = Arc::new(StringArray::from(vec![ + Option::<&str>::None, + None, + None, + None, + None, + ])) as _; + builder.vectorized_append(&all_nulls_input_array, &[0, 1, 2, 3, 4]); + + let mut equal_to_results = vec![true; all_nulls_input_array.len()]; + builder.vectorized_equal_to( + &[0, 1, 2, 3, 4], + &all_nulls_input_array, + &[0, 1, 
2, 3, 4], + &mut equal_to_results, + ); + + assert!(equal_to_results[0]); + assert!(equal_to_results[1]); + assert!(equal_to_results[2]); + assert!(equal_to_results[3]); + assert!(equal_to_results[4]); + + // All not nulls input array + let all_not_nulls_input_array = Arc::new(StringArray::from(vec![ + Some("string1"), + Some("string2"), + Some("string3"), + Some("string4"), + Some("string5"), + ])) as _; + builder.vectorized_append(&all_not_nulls_input_array, &[0, 1, 2, 3, 4]); + + let mut equal_to_results = vec![true; all_not_nulls_input_array.len()]; + builder.vectorized_equal_to( + &[5, 6, 7, 8, 9], + &all_not_nulls_input_array, + &[0, 1, 2, 3, 4], + &mut equal_to_results, + ); + + assert!(equal_to_results[0]); + assert!(equal_to_results[1]); + assert!(equal_to_results[2]); + assert!(equal_to_results[3]); + assert!(equal_to_results[4]); + } + + fn test_byte_equal_to_internal(mut append: A, mut equal_to: E) + where + A: FnMut(&mut ByteGroupValueBuilder, &ArrayRef, &[usize]), + E: FnMut( + &ByteGroupValueBuilder, + &[usize], + &ArrayRef, + &[usize], + &mut Vec, + ), + { + // Will cover such cases: + // - exist null, input not null + // - exist null, input null; values not equal + // - exist null, input null; values equal + // - exist not null, input null + // - exist not null, input not null; values not equal + // - exist not null, input not null; values equal + + // Define PrimitiveGroupValueBuilder + let mut builder = ByteGroupValueBuilder::::new(OutputType::Utf8); + let builder_array = Arc::new(StringArray::from(vec![ + None, + None, + None, + Some("foo"), + Some("bar"), + Some("baz"), + ])) as ArrayRef; + append(&mut builder, &builder_array, &[0, 1, 2, 3, 4, 5]); + + // Define input array + let (offsets, buffer, _nulls) = StringArray::from(vec![ + Some("foo"), + Some("bar"), + None, + None, + Some("foo"), + Some("baz"), + ]) + .into_parts(); + + // explicitly build a boolean buffer where one of the null values also happens to match + let mut boolean_buffer_builder = BooleanBufferBuilder::new(6); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(false); // this sets Some("bar") to null above + boolean_buffer_builder.append(false); + boolean_buffer_builder.append(false); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(true); + let nulls = NullBuffer::new(boolean_buffer_builder.finish()); + let input_array = + Arc::new(StringArray::new(offsets, buffer, Some(nulls))) as ArrayRef; + + // Check + let mut equal_to_results = vec![true; builder.len()]; + equal_to( + &builder, + &[0, 1, 2, 3, 4, 5], + &input_array, + &[0, 1, 2, 3, 4, 5], + &mut equal_to_results, + ); + + assert!(!equal_to_results[0]); + assert!(equal_to_results[1]); + assert!(equal_to_results[2]); + assert!(!equal_to_results[3]); + assert!(!equal_to_results[4]); + assert!(equal_to_results[5]); + } +} diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_column/bytes_view.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_column/bytes_view.rs new file mode 100644 index 000000000000..032b4d9e2a91 --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_column/bytes_view.rs @@ -0,0 +1,911 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::aggregates::group_values::multi_column::{nulls_equal_to, GroupColumn}; +use crate::aggregates::group_values::null_builder::MaybeNullBufferBuilder; +use arrow::array::{make_view, AsArray, ByteView}; +use arrow::buffer::ScalarBuffer; +use arrow::datatypes::ByteViewType; +use arrow_array::{Array, ArrayRef, GenericByteViewArray}; +use arrow_buffer::Buffer; +use itertools::izip; +use std::marker::PhantomData; +use std::mem::{replace, size_of}; +use std::sync::Arc; + +const BYTE_VIEW_MAX_BLOCK_SIZE: usize = 2 * 1024 * 1024; + +/// An implementation of [`GroupColumn`] for binary view and utf8 view types. +/// +/// Stores a collection of binary view or utf8 view group values in a buffer +/// whose structure is similar to `GenericByteViewArray`, and we can get benefits: +/// +/// 1. Efficient comparison of incoming rows to existing rows +/// 2. Efficient construction of the final output array +/// 3. Efficient to perform `take_n` comparing to use `GenericByteViewBuilder` +pub struct ByteViewGroupValueBuilder { + /// The views of string values + /// + /// If string len <= 12, the view's format will be: + /// string(12B) | len(4B) + /// + /// If string len > 12, its format will be: + /// offset(4B) | buffer_index(4B) | prefix(4B) | len(4B) + views: Vec, + + /// The progressing block + /// + /// New values will be inserted into it until its capacity + /// is not enough(detail can see `max_block_size`). + in_progress: Vec, + + /// The completed blocks + completed: Vec, + + /// The max size of `in_progress` + /// + /// `in_progress` will be flushed into `completed`, and create new `in_progress` + /// when found its remaining capacity(`max_block_size` - `len(in_progress)`), + /// is no enough to store the appended value. + /// + /// Currently it is fixed at 2MB. 
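+    ///
+    /// For example (illustrative, assuming the default 2MB limit): appending a
+    /// 20 byte non-inlined value while `in_progress` already holds `2MB - 10`
+    /// bytes first flushes `in_progress` into `completed`, then writes the
+    /// value into a fresh `in_progress` block.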
+ max_block_size: usize, + + /// Nulls + nulls: MaybeNullBufferBuilder, + + /// phantom data so the type requires `` + _phantom: PhantomData, +} + +impl ByteViewGroupValueBuilder { + pub fn new() -> Self { + Self { + views: Vec::new(), + in_progress: Vec::new(), + completed: Vec::new(), + max_block_size: BYTE_VIEW_MAX_BLOCK_SIZE, + nulls: MaybeNullBufferBuilder::new(), + _phantom: PhantomData {}, + } + } + + /// Set the max block size + fn with_max_block_size(mut self, max_block_size: usize) -> Self { + self.max_block_size = max_block_size; + self + } + + fn equal_to_inner(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool { + let array = array.as_byte_view::(); + self.do_equal_to_inner(lhs_row, array, rhs_row) + } + + fn append_val_inner(&mut self, array: &ArrayRef, row: usize) { + let arr = array.as_byte_view::(); + + // Null row case, set and return + if arr.is_null(row) { + self.nulls.append(true); + self.views.push(0); + return; + } + + // Not null row case + self.nulls.append(false); + self.do_append_val_inner(arr, row); + } + + fn vectorized_equal_to_inner( + &self, + lhs_rows: &[usize], + array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut [bool], + ) { + let array = array.as_byte_view::(); + + let iter = izip!( + lhs_rows.iter(), + rhs_rows.iter(), + equal_to_results.iter_mut(), + ); + + for (&lhs_row, &rhs_row, equal_to_result) in iter { + // Has found not equal to, don't need to check + if !*equal_to_result { + continue; + } + + *equal_to_result = self.do_equal_to_inner(lhs_row, array, rhs_row); + } + } + + fn vectorized_append_inner(&mut self, array: &ArrayRef, rows: &[usize]) { + let arr = array.as_byte_view::(); + let null_count = array.null_count(); + let num_rows = array.len(); + let all_null_or_non_null = if null_count == 0 { + Some(true) + } else if null_count == num_rows { + Some(false) + } else { + None + }; + + match all_null_or_non_null { + None => { + for &row in rows { + // Null row case, set and return + if arr.is_valid(row) { + self.nulls.append(false); + self.do_append_val_inner(arr, row); + } else { + self.nulls.append(true); + self.views.push(0); + } + } + } + + Some(true) => { + self.nulls.append_n(rows.len(), false); + for &row in rows { + self.do_append_val_inner(arr, row); + } + } + + Some(false) => { + self.nulls.append_n(rows.len(), true); + let new_len = self.views.len() + rows.len(); + self.views.resize(new_len, 0); + } + } + } + + fn do_append_val_inner(&mut self, array: &GenericByteViewArray, row: usize) + where + B: ByteViewType, + { + let value: &[u8] = array.value(row).as_ref(); + + let value_len = value.len(); + let view = if value_len <= 12 { + make_view(value, 0, 0) + } else { + // Ensure big enough block to hold the value firstly + self.ensure_in_progress_big_enough(value_len); + + // Append value + let buffer_index = self.completed.len(); + let offset = self.in_progress.len(); + self.in_progress.extend_from_slice(value); + + make_view(value, buffer_index as u32, offset as u32) + }; + + // Append view + self.views.push(view); + } + + fn ensure_in_progress_big_enough(&mut self, value_len: usize) { + debug_assert!(value_len > 12); + let require_cap = self.in_progress.len() + value_len; + + // If current block isn't big enough, flush it and create a new in progress block + if require_cap > self.max_block_size { + let flushed_block = replace( + &mut self.in_progress, + Vec::with_capacity(self.max_block_size), + ); + let buffer = Buffer::from_vec(flushed_block); + self.completed.push(buffer); + } + } + + fn 
do_equal_to_inner( + &self, + lhs_row: usize, + array: &GenericByteViewArray, + rhs_row: usize, + ) -> bool { + // Check if nulls equal firstly + let exist_null = self.nulls.is_null(lhs_row); + let input_null = array.is_null(rhs_row); + if let Some(result) = nulls_equal_to(exist_null, input_null) { + return result; + } + + // Otherwise, we need to check their values + let exist_view = self.views[lhs_row]; + let exist_view_len = exist_view as u32; + + let input_view = array.views()[rhs_row]; + let input_view_len = input_view as u32; + + // The check logic + // - Check len equality + // - If inlined, check inlined value + // - If non-inlined, check prefix and then check value in buffer + // when needed + if exist_view_len != input_view_len { + return false; + } + + if exist_view_len <= 12 { + let exist_inline = unsafe { + GenericByteViewArray::::inline_value( + &exist_view, + exist_view_len as usize, + ) + }; + let input_inline = unsafe { + GenericByteViewArray::::inline_value( + &input_view, + input_view_len as usize, + ) + }; + exist_inline == input_inline + } else { + let exist_prefix = + unsafe { GenericByteViewArray::::inline_value(&exist_view, 4) }; + let input_prefix = + unsafe { GenericByteViewArray::::inline_value(&input_view, 4) }; + + if exist_prefix != input_prefix { + return false; + } + + let exist_full = { + let byte_view = ByteView::from(exist_view); + self.value( + byte_view.buffer_index as usize, + byte_view.offset as usize, + byte_view.length as usize, + ) + }; + let input_full: &[u8] = unsafe { array.value_unchecked(rhs_row).as_ref() }; + exist_full == input_full + } + } + + fn value(&self, buffer_index: usize, offset: usize, length: usize) -> &[u8] { + debug_assert!(buffer_index <= self.completed.len()); + + if buffer_index < self.completed.len() { + let block = &self.completed[buffer_index]; + &block[offset..offset + length] + } else { + &self.in_progress[offset..offset + length] + } + } + + fn build_inner(self) -> ArrayRef { + let Self { + views, + in_progress, + mut completed, + nulls, + .. + } = self; + + // Build nulls + let null_buffer = nulls.build(); + + // Build values + // Flush `in_process` firstly + if !in_progress.is_empty() { + let buffer = Buffer::from(in_progress); + completed.push(buffer); + } + + let views = ScalarBuffer::from(views); + + // Safety: + // * all views were correctly made + // * (if utf8): Input was valid Utf8 so buffer contents are + // valid utf8 as well + unsafe { + Arc::new(GenericByteViewArray::::new_unchecked( + views, + completed, + null_buffer, + )) + } + } + + fn take_n_inner(&mut self, n: usize) -> ArrayRef { + debug_assert!(self.len() >= n); + + // The `n == len` case, we need to take all + if self.len() == n { + let new_builder = Self::new().with_max_block_size(self.max_block_size); + let cur_builder = replace(self, new_builder); + return cur_builder.build_inner(); + } + + // The `n < len` case + // Take n for nulls + let null_buffer = self.nulls.take_n(n); + + // Take n for values: + // - Take first n `view`s from `views` + // + // - Find the last non-inlined `view`, if all inlined, + // we can build array and return happily, otherwise we + // we need to continue to process related buffers + // + // - Get the last related `buffer index`(let's name it `buffer index n`) + // from last non-inlined `view` + // + // - Take buffers, the key is that we need to know if we need to take + // the whole last related buffer. 
The logic is a bit complex, you can + // detail in `take_buffers_with_whole_last`, `take_buffers_with_partial_last` + // and other related steps in following + // + // - Shift the `buffer index` of remaining non-inlined `views` + // + let first_n_views = self.views.drain(0..n).collect::>(); + + let last_non_inlined_view = first_n_views + .iter() + .rev() + .find(|view| ((**view) as u32) > 12); + + // All taken views inlined + let Some(view) = last_non_inlined_view else { + let views = ScalarBuffer::from(first_n_views); + + // Safety: + // * all views were correctly made + // * (if utf8): Input was valid Utf8 so buffer contents are + // valid utf8 as well + unsafe { + return Arc::new(GenericByteViewArray::::new_unchecked( + views, + Vec::new(), + null_buffer, + )); + } + }; + + // Unfortunately, some taken views non-inlined + let view = ByteView::from(*view); + let last_remaining_buffer_index = view.buffer_index as usize; + + // Check should we take the whole `last_remaining_buffer_index` buffer + let take_whole_last_buffer = self.should_take_whole_buffer( + last_remaining_buffer_index, + (view.offset + view.length) as usize, + ); + + // Take related buffers + let buffers = if take_whole_last_buffer { + self.take_buffers_with_whole_last(last_remaining_buffer_index) + } else { + self.take_buffers_with_partial_last( + last_remaining_buffer_index, + (view.offset + view.length) as usize, + ) + }; + + // Shift `buffer index`s finally + let shifts = if take_whole_last_buffer { + last_remaining_buffer_index + 1 + } else { + last_remaining_buffer_index + }; + + self.views.iter_mut().for_each(|view| { + if (*view as u32) > 12 { + let mut byte_view = ByteView::from(*view); + byte_view.buffer_index -= shifts as u32; + *view = byte_view.as_u128(); + } + }); + + // Build array and return + let views = ScalarBuffer::from(first_n_views); + + // Safety: + // * all views were correctly made + // * (if utf8): Input was valid Utf8 so buffer contents are + // valid utf8 as well + unsafe { + Arc::new(GenericByteViewArray::::new_unchecked( + views, + buffers, + null_buffer, + )) + } + } + + fn take_buffers_with_whole_last( + &mut self, + last_remaining_buffer_index: usize, + ) -> Vec { + if last_remaining_buffer_index == self.completed.len() { + self.flush_in_progress(); + } + self.completed + .drain(0..last_remaining_buffer_index + 1) + .collect() + } + + fn take_buffers_with_partial_last( + &mut self, + last_remaining_buffer_index: usize, + last_take_len: usize, + ) -> Vec { + let mut take_buffers = Vec::with_capacity(last_remaining_buffer_index + 1); + + // Take `0 ~ last_remaining_buffer_index - 1` buffers + if !self.completed.is_empty() || last_remaining_buffer_index == 0 { + take_buffers.extend(self.completed.drain(0..last_remaining_buffer_index)); + } + + // Process the `last_remaining_buffer_index` buffers + let last_buffer = if last_remaining_buffer_index < self.completed.len() { + // If it is in `completed`, simply clone + self.completed[last_remaining_buffer_index].clone() + } else { + // If it is `in_progress`, copied `0 ~ offset` part + let taken_last_buffer = self.in_progress[0..last_take_len].to_vec(); + Buffer::from_vec(taken_last_buffer) + }; + take_buffers.push(last_buffer); + + take_buffers + } + + #[inline] + fn should_take_whole_buffer(&self, buffer_index: usize, take_len: usize) -> bool { + if buffer_index < self.completed.len() { + take_len == self.completed[buffer_index].len() + } else { + take_len == self.in_progress.len() + } + } + + fn flush_in_progress(&mut self) { + let flushed_block 
= replace( + &mut self.in_progress, + Vec::with_capacity(self.max_block_size), + ); + let buffer = Buffer::from_vec(flushed_block); + self.completed.push(buffer); + } +} + +impl GroupColumn for ByteViewGroupValueBuilder { + fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool { + self.equal_to_inner(lhs_row, array, rhs_row) + } + + fn append_val(&mut self, array: &ArrayRef, row: usize) { + self.append_val_inner(array, row) + } + + fn vectorized_equal_to( + &self, + group_indices: &[usize], + array: &ArrayRef, + rows: &[usize], + equal_to_results: &mut [bool], + ) { + self.vectorized_equal_to_inner(group_indices, array, rows, equal_to_results); + } + + fn vectorized_append(&mut self, array: &ArrayRef, rows: &[usize]) { + self.vectorized_append_inner(array, rows); + } + + fn len(&self) -> usize { + self.views.len() + } + + fn size(&self) -> usize { + let buffers_size = self + .completed + .iter() + .map(|buf| buf.capacity() * size_of::()) + .sum::(); + + self.nulls.allocated_size() + + self.views.capacity() * size_of::() + + self.in_progress.capacity() * size_of::() + + buffers_size + + size_of::() + } + + fn build(self: Box) -> ArrayRef { + Self::build_inner(*self) + } + + fn take_n(&mut self, n: usize) -> ArrayRef { + self.take_n_inner(n) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use crate::aggregates::group_values::multi_column::bytes_view::ByteViewGroupValueBuilder; + use arrow::array::AsArray; + use arrow::datatypes::StringViewType; + use arrow_array::{ArrayRef, StringViewArray}; + use arrow_buffer::{BooleanBufferBuilder, NullBuffer}; + + use super::GroupColumn; + + #[test] + fn test_byte_view_append_val() { + let mut builder = + ByteViewGroupValueBuilder::::new().with_max_block_size(60); + let builder_array = StringViewArray::from(vec![ + Some("this string is quite long"), // in buffer 0 + Some("foo"), + None, + Some("bar"), + Some("this string is also quite long"), // buffer 0 + Some("this string is quite long"), // buffer 1 + Some("bar"), + ]); + let builder_array: ArrayRef = Arc::new(builder_array); + for row in 0..builder_array.len() { + builder.append_val(&builder_array, row); + } + + let output = Box::new(builder).build(); + // should be 2 output buffers to hold all the data + assert_eq!(output.as_string_view().data_buffers().len(), 2); + assert_eq!(&output, &builder_array) + } + + #[test] + fn test_byte_view_equal_to() { + let append = |builder: &mut ByteViewGroupValueBuilder, + builder_array: &ArrayRef, + append_rows: &[usize]| { + for &index in append_rows { + builder.append_val(builder_array, index); + } + }; + + let equal_to = |builder: &ByteViewGroupValueBuilder, + lhs_rows: &[usize], + input_array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut Vec| { + let iter = lhs_rows.iter().zip(rhs_rows.iter()); + for (idx, (&lhs_row, &rhs_row)) in iter.enumerate() { + equal_to_results[idx] = builder.equal_to(lhs_row, input_array, rhs_row); + } + }; + + test_byte_view_equal_to_internal(append, equal_to); + } + + #[test] + fn test_byte_view_vectorized_equal_to() { + let append = |builder: &mut ByteViewGroupValueBuilder, + builder_array: &ArrayRef, + append_rows: &[usize]| { + builder.vectorized_append(builder_array, append_rows); + }; + + let equal_to = |builder: &ByteViewGroupValueBuilder, + lhs_rows: &[usize], + input_array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut Vec| { + builder.vectorized_equal_to( + lhs_rows, + input_array, + rhs_rows, + equal_to_results, + ); + }; + + 
test_byte_view_equal_to_internal(append, equal_to); + } + + #[test] + fn test_byte_view_vectorized_operation_special_case() { + // Test the special `all nulls` or `not nulls` input array case + // for vectorized append and equal to + + let mut builder = + ByteViewGroupValueBuilder::::new().with_max_block_size(60); + + // All nulls input array + let all_nulls_input_array = Arc::new(StringViewArray::from(vec![ + Option::<&str>::None, + None, + None, + None, + None, + ])) as _; + builder.vectorized_append(&all_nulls_input_array, &[0, 1, 2, 3, 4]); + + let mut equal_to_results = vec![true; all_nulls_input_array.len()]; + builder.vectorized_equal_to( + &[0, 1, 2, 3, 4], + &all_nulls_input_array, + &[0, 1, 2, 3, 4], + &mut equal_to_results, + ); + + assert!(equal_to_results[0]); + assert!(equal_to_results[1]); + assert!(equal_to_results[2]); + assert!(equal_to_results[3]); + assert!(equal_to_results[4]); + + // All not nulls input array + let all_not_nulls_input_array = Arc::new(StringViewArray::from(vec![ + Some("stringview1"), + Some("stringview2"), + Some("stringview3"), + Some("stringview4"), + Some("stringview5"), + ])) as _; + builder.vectorized_append(&all_not_nulls_input_array, &[0, 1, 2, 3, 4]); + + let mut equal_to_results = vec![true; all_not_nulls_input_array.len()]; + builder.vectorized_equal_to( + &[5, 6, 7, 8, 9], + &all_not_nulls_input_array, + &[0, 1, 2, 3, 4], + &mut equal_to_results, + ); + + assert!(equal_to_results[0]); + assert!(equal_to_results[1]); + assert!(equal_to_results[2]); + assert!(equal_to_results[3]); + assert!(equal_to_results[4]); + } + + fn test_byte_view_equal_to_internal(mut append: A, mut equal_to: E) + where + A: FnMut(&mut ByteViewGroupValueBuilder, &ArrayRef, &[usize]), + E: FnMut( + &ByteViewGroupValueBuilder, + &[usize], + &ArrayRef, + &[usize], + &mut Vec, + ), + { + // Will cover such cases: + // - exist null, input not null + // - exist null, input null; values not equal + // - exist null, input null; values equal + // - exist not null, input null + // - exist not null, input not null; value lens not equal + // - exist not null, input not null; value not equal(inlined case) + // - exist not null, input not null; value equal(inlined case) + // + // - exist not null, input not null; value not equal + // (non-inlined case + prefix not equal) + // + // - exist not null, input not null; value not equal + // (non-inlined case + value in `completed`) + // + // - exist not null, input not null; value equal + // (non-inlined case + value in `completed`) + // + // - exist not null, input not null; value not equal + // (non-inlined case + value in `in_progress`) + // + // - exist not null, input not null; value equal + // (non-inlined case + value in `in_progress`) + + // Set the block size to 40 for ensuring some unlined values are in `in_progress`, + // and some are in `completed`, so both two branches in `value` function can be covered. 
+ let mut builder = + ByteViewGroupValueBuilder::::new().with_max_block_size(60); + let builder_array = Arc::new(StringViewArray::from(vec![ + None, + None, + None, + Some("foo"), + Some("bazz"), + Some("foo"), + Some("bar"), + Some("I am a long string for test eq in completed"), + Some("I am a long string for test eq in progress"), + ])) as ArrayRef; + append(&mut builder, &builder_array, &[0, 1, 2, 3, 4, 5, 6, 7, 8]); + + // Define input array + let (views, buffer, _nulls) = StringViewArray::from(vec![ + Some("foo"), + Some("bar"), // set to null + None, + None, + Some("baz"), + Some("oof"), + Some("bar"), + Some("i am a long string for test eq in completed"), + Some("I am a long string for test eq in COMPLETED"), + Some("I am a long string for test eq in completed"), + Some("I am a long string for test eq in PROGRESS"), + Some("I am a long string for test eq in progress"), + ]) + .into_parts(); + + // explicitly build a boolean buffer where one of the null values also happens to match + let mut boolean_buffer_builder = BooleanBufferBuilder::new(9); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(false); // this sets Some("bar") to null above + boolean_buffer_builder.append(false); + boolean_buffer_builder.append(false); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(true); + let nulls = NullBuffer::new(boolean_buffer_builder.finish()); + let input_array = + Arc::new(StringViewArray::new(views, buffer, Some(nulls))) as ArrayRef; + + // Check + let mut equal_to_results = vec![true; input_array.len()]; + equal_to( + &builder, + &[0, 1, 2, 3, 4, 5, 6, 7, 7, 7, 8, 8], + &input_array, + &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], + &mut equal_to_results, + ); + + assert!(!equal_to_results[0]); + assert!(equal_to_results[1]); + assert!(equal_to_results[2]); + assert!(!equal_to_results[3]); + assert!(!equal_to_results[4]); + assert!(!equal_to_results[5]); + assert!(equal_to_results[6]); + assert!(!equal_to_results[7]); + assert!(!equal_to_results[8]); + assert!(equal_to_results[9]); + assert!(!equal_to_results[10]); + assert!(equal_to_results[11]); + } + + #[test] + fn test_byte_view_take_n() { + // ####### Define cases and init ####### + + // `take_n` is really complex, we should consider and test following situations: + // 1. Take nulls + // 2. Take all `inlined`s + // 3. Take non-inlined + partial last buffer in `completed` + // 4. Take non-inlined + whole last buffer in `completed` + // 5. Take non-inlined + partial last `in_progress` + // 6. Take non-inlined + whole last buffer in `in_progress` + // 7. 
Take all views at once + + let mut builder = + ByteViewGroupValueBuilder::::new().with_max_block_size(60); + let input_array = StringViewArray::from(vec![ + // Test situation 1 + None, + None, + // Test situation 2 (also test take null together) + None, + Some("foo"), + Some("bar"), + // Test situation 3 (also test take null + inlined) + None, + Some("foo"), + Some("this string is quite long"), + Some("this string is also quite long"), + // Test situation 4 (also test take null + inlined) + None, + Some("bar"), + Some("this string is quite long"), + // Test situation 5 (also test take null + inlined) + None, + Some("foo"), + Some("another string that is is quite long"), + Some("this string not so long"), + // Test situation 6 (also test take null + inlined + insert again after taking) + None, + Some("bar"), + Some("this string is quite long"), + // Insert 4 and just take 3 to ensure it will go the path of situation 6 + None, + // Finally, we create a new builder, insert the whole array and then + // take whole at once for testing situation 7 + ]); + + let input_array: ArrayRef = Arc::new(input_array); + let first_ones_to_append = 16; // For testing situation 1~5 + let second_ones_to_append = 4; // For testing situation 6 + let final_ones_to_append = input_array.len(); // For testing situation 7 + + // ####### Test situation 1~5 ####### + for row in 0..first_ones_to_append { + builder.append_val(&input_array, row); + } + + assert_eq!(builder.completed.len(), 2); + assert_eq!(builder.in_progress.len(), 59); + + // Situation 1 + let taken_array = builder.take_n(2); + assert_eq!(&taken_array, &input_array.slice(0, 2)); + + // Situation 2 + let taken_array = builder.take_n(3); + assert_eq!(&taken_array, &input_array.slice(2, 3)); + + // Situation 3 + let taken_array = builder.take_n(3); + assert_eq!(&taken_array, &input_array.slice(5, 3)); + + let taken_array = builder.take_n(1); + assert_eq!(&taken_array, &input_array.slice(8, 1)); + + // Situation 4 + let taken_array = builder.take_n(3); + assert_eq!(&taken_array, &input_array.slice(9, 3)); + + // Situation 5 + let taken_array = builder.take_n(3); + assert_eq!(&taken_array, &input_array.slice(12, 3)); + + let taken_array = builder.take_n(1); + assert_eq!(&taken_array, &input_array.slice(15, 1)); + + // ####### Test situation 6 ####### + assert!(builder.completed.is_empty()); + assert!(builder.in_progress.is_empty()); + assert!(builder.views.is_empty()); + + for row in first_ones_to_append..first_ones_to_append + second_ones_to_append { + builder.append_val(&input_array, row); + } + + assert!(builder.completed.is_empty()); + assert_eq!(builder.in_progress.len(), 25); + + let taken_array = builder.take_n(3); + assert_eq!(&taken_array, &input_array.slice(16, 3)); + + // ####### Test situation 7 ####### + // Create a new builder + let mut builder = + ByteViewGroupValueBuilder::::new().with_max_block_size(60); + + for row in 0..final_ones_to_append { + builder.append_val(&input_array, row); + } + + assert_eq!(builder.completed.len(), 3); + assert_eq!(builder.in_progress.len(), 25); + + let taken_array = builder.take_n(final_ones_to_append); + assert_eq!(&taken_array, &input_array); + } +} diff --git a/datafusion/physical-plan/src/aggregates/group_values/column.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_column/mod.rs similarity index 95% rename from datafusion/physical-plan/src/aggregates/group_values/column.rs rename to datafusion/physical-plan/src/aggregates/group_values/multi_column/mod.rs index 8100bb876ded..191292c549f4 
100644 --- a/datafusion/physical-plan/src/aggregates/group_values/column.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_column/mod.rs @@ -15,11 +15,15 @@ // specific language governing permissions and limitations // under the License. +mod bytes; +mod bytes_view; +mod primitive; + use std::mem::{self, size_of}; -use crate::aggregates::group_values::group_column::{ - ByteGroupValueBuilder, ByteViewGroupValueBuilder, GroupColumn, - PrimitiveGroupValueBuilder, +use crate::aggregates::group_values::multi_column::{ + bytes::ByteGroupValueBuilder, bytes_view::ByteViewGroupValueBuilder, + primitive::PrimitiveGroupValueBuilder, }; use crate::aggregates::group_values::GroupValues; use ahash::RandomState; @@ -43,6 +47,72 @@ use hashbrown::raw::RawTable; const NON_INLINED_FLAG: u64 = 0x8000000000000000; const VALUE_MASK: u64 = 0x7FFFFFFFFFFFFFFF; +/// Trait for storing a single column of group values in [`GroupValuesColumn`] +/// +/// Implementations of this trait store an in-progress collection of group values +/// (similar to various builders in Arrow-rs) that allow for quick comparison to +/// incoming rows. +/// +/// [`GroupValuesColumn`]: crate::aggregates::group_values::GroupValuesColumn +pub trait GroupColumn: Send + Sync { + /// Returns equal if the row stored in this builder at `lhs_row` is equal to + /// the row in `array` at `rhs_row` + /// + /// Note that this comparison returns true if both elements are NULL + fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool; + + /// Appends the row at `row` in `array` to this builder + fn append_val(&mut self, array: &ArrayRef, row: usize); + + /// The vectorized version equal to + /// + /// When found nth row stored in this builder at `lhs_row` + /// is equal to the row in `array` at `rhs_row`, + /// it will record the `true` result at the corresponding + /// position in `equal_to_results`. + /// + /// And if found nth result in `equal_to_results` is already + /// `false`, the check for nth row will be skipped. + /// + fn vectorized_equal_to( + &self, + lhs_rows: &[usize], + array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut [bool], + ); + + /// The vectorized version `append_val` + fn vectorized_append(&mut self, array: &ArrayRef, rows: &[usize]); + + /// Returns the number of rows stored in this builder + fn len(&self) -> usize; + + /// Returns the number of bytes used by this [`GroupColumn`] + fn size(&self) -> usize; + + /// Builds a new array from all of the stored rows + fn build(self: Box) -> ArrayRef; + + /// Builds a new array from the first `n` stored rows, shifting the + /// remaining rows to the start of the builder + fn take_n(&mut self, n: usize) -> ArrayRef; +} + +/// Determines if the nullability of the existing and new input array can be used +/// to short-circuit the comparison of the two values. +/// +/// Returns `Some(result)` if the result of the comparison can be determined +/// from the nullness of the two values, and `None` if the comparison must be +/// done on the values themselves. 
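+///
+/// Put differently:
+/// - both null        => `Some(true)`  (two NULL group values compare equal)
+/// - exactly one null => `Some(false)` (a NULL can never equal a non-NULL)
+/// - both non-null    => `None`        (the values themselves must be compared)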
+pub fn nulls_equal_to(lhs_null: bool, rhs_null: bool) -> Option { + match (lhs_null, rhs_null) { + (true, true) => Some(true), + (false, true) | (true, false) => Some(false), + _ => None, + } +} + /// The view of indices pointing to the actual values in `GroupValues` /// /// If only single `group index` represented by view, @@ -1068,7 +1138,7 @@ mod tests { use datafusion_common::utils::proxy::RawTableAllocExt; use datafusion_expr::EmitTo; - use crate::aggregates::group_values::{column::GroupValuesColumn, GroupValues}; + use crate::aggregates::group_values::{multi_column::GroupValuesColumn, GroupValues}; use super::GroupIndexView; diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_column/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_column/primitive.rs new file mode 100644 index 000000000000..dff85ff7eb1a --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_column/primitive.rs @@ -0,0 +1,472 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
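+
+// A minimal usage sketch (the concrete Arrow type and the nullability flag
+// below are illustrative, not mandated by this patch): a nullable Int64 group
+// column is declared as `PrimitiveGroupValueBuilder::<Int64Type, true>::new()`,
+// rows are appended with `append_val` / `vectorized_append`, and the collected
+// group values are emitted as an `ArrayRef` via `build` or `take_n`.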
+ +use crate::aggregates::group_values::multi_column::{nulls_equal_to, GroupColumn}; +use crate::aggregates::group_values::null_builder::MaybeNullBufferBuilder; +use arrow::buffer::ScalarBuffer; +use arrow_array::cast::AsArray; +use arrow_array::{Array, ArrayRef, ArrowPrimitiveType, PrimitiveArray}; +use datafusion_execution::memory_pool::proxy::VecAllocExt; +use itertools::izip; +use std::iter; +use std::sync::Arc; + +/// An implementation of [`GroupColumn`] for primitive values +/// +/// Optimized to skip null buffer construction if the input is known to be non nullable +/// +/// # Template parameters +/// +/// `T`: the native Rust type that stores the data +/// `NULLABLE`: if the data can contain any nulls +#[derive(Debug)] +pub struct PrimitiveGroupValueBuilder { + group_values: Vec, + nulls: MaybeNullBufferBuilder, +} + +impl PrimitiveGroupValueBuilder +where + T: ArrowPrimitiveType, +{ + /// Create a new `PrimitiveGroupValueBuilder` + pub fn new() -> Self { + Self { + group_values: vec![], + nulls: MaybeNullBufferBuilder::new(), + } + } +} + +impl GroupColumn + for PrimitiveGroupValueBuilder +{ + fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool { + // Perf: skip null check (by short circuit) if input is not nullable + if NULLABLE { + let exist_null = self.nulls.is_null(lhs_row); + let input_null = array.is_null(rhs_row); + if let Some(result) = nulls_equal_to(exist_null, input_null) { + return result; + } + // Otherwise, we need to check their values + } + + self.group_values[lhs_row] == array.as_primitive::().value(rhs_row) + } + + fn append_val(&mut self, array: &ArrayRef, row: usize) { + // Perf: skip null check if input can't have nulls + if NULLABLE { + if array.is_null(row) { + self.nulls.append(true); + self.group_values.push(T::default_value()); + } else { + self.nulls.append(false); + self.group_values.push(array.as_primitive::().value(row)); + } + } else { + self.group_values.push(array.as_primitive::().value(row)); + } + } + + fn vectorized_equal_to( + &self, + lhs_rows: &[usize], + array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut [bool], + ) { + let array = array.as_primitive::(); + + let iter = izip!( + lhs_rows.iter(), + rhs_rows.iter(), + equal_to_results.iter_mut(), + ); + + for (&lhs_row, &rhs_row, equal_to_result) in iter { + // Has found not equal to in previous column, don't need to check + if !*equal_to_result { + continue; + } + + // Perf: skip null check (by short circuit) if input is not nullable + if NULLABLE { + let exist_null = self.nulls.is_null(lhs_row); + let input_null = array.is_null(rhs_row); + if let Some(result) = nulls_equal_to(exist_null, input_null) { + *equal_to_result = result; + continue; + } + // Otherwise, we need to check their values + } + + *equal_to_result = self.group_values[lhs_row] == array.value(rhs_row); + } + } + + fn vectorized_append(&mut self, array: &ArrayRef, rows: &[usize]) { + let arr = array.as_primitive::(); + + let null_count = array.null_count(); + let num_rows = array.len(); + let all_null_or_non_null = if null_count == 0 { + Some(true) + } else if null_count == num_rows { + Some(false) + } else { + None + }; + + match (NULLABLE, all_null_or_non_null) { + (true, None) => { + for &row in rows { + if array.is_null(row) { + self.nulls.append(true); + self.group_values.push(T::default_value()); + } else { + self.nulls.append(false); + self.group_values.push(arr.value(row)); + } + } + } + + (true, Some(true)) => { + self.nulls.append_n(rows.len(), false); + for &row in rows { + 
self.group_values.push(arr.value(row)); + } + } + + (true, Some(false)) => { + self.nulls.append_n(rows.len(), true); + self.group_values + .extend(iter::repeat(T::default_value()).take(rows.len())); + } + + (false, _) => { + for &row in rows { + self.group_values.push(arr.value(row)); + } + } + } + } + + fn len(&self) -> usize { + self.group_values.len() + } + + fn size(&self) -> usize { + self.group_values.allocated_size() + self.nulls.allocated_size() + } + + fn build(self: Box) -> ArrayRef { + let Self { + group_values, + nulls, + } = *self; + + let nulls = nulls.build(); + if !NULLABLE { + assert!(nulls.is_none(), "unexpected nulls in non nullable input"); + } + + Arc::new(PrimitiveArray::::new( + ScalarBuffer::from(group_values), + nulls, + )) + } + + fn take_n(&mut self, n: usize) -> ArrayRef { + let first_n = self.group_values.drain(0..n).collect::>(); + + let first_n_nulls = if NULLABLE { self.nulls.take_n(n) } else { None }; + + Arc::new(PrimitiveArray::::new( + ScalarBuffer::from(first_n), + first_n_nulls, + )) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use crate::aggregates::group_values::multi_column::primitive::PrimitiveGroupValueBuilder; + use arrow::datatypes::Int64Type; + use arrow_array::{ArrayRef, Int64Array}; + use arrow_buffer::{BooleanBufferBuilder, NullBuffer}; + + use super::GroupColumn; + + #[test] + fn test_nullable_primitive_equal_to() { + let append = |builder: &mut PrimitiveGroupValueBuilder, + builder_array: &ArrayRef, + append_rows: &[usize]| { + for &index in append_rows { + builder.append_val(builder_array, index); + } + }; + + let equal_to = |builder: &PrimitiveGroupValueBuilder, + lhs_rows: &[usize], + input_array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut Vec| { + let iter = lhs_rows.iter().zip(rhs_rows.iter()); + for (idx, (&lhs_row, &rhs_row)) in iter.enumerate() { + equal_to_results[idx] = builder.equal_to(lhs_row, input_array, rhs_row); + } + }; + + test_nullable_primitive_equal_to_internal(append, equal_to); + } + + #[test] + fn test_nullable_primitive_vectorized_equal_to() { + let append = |builder: &mut PrimitiveGroupValueBuilder, + builder_array: &ArrayRef, + append_rows: &[usize]| { + builder.vectorized_append(builder_array, append_rows); + }; + + let equal_to = |builder: &PrimitiveGroupValueBuilder, + lhs_rows: &[usize], + input_array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut Vec| { + builder.vectorized_equal_to( + lhs_rows, + input_array, + rhs_rows, + equal_to_results, + ); + }; + + test_nullable_primitive_equal_to_internal(append, equal_to); + } + + fn test_nullable_primitive_equal_to_internal(mut append: A, mut equal_to: E) + where + A: FnMut(&mut PrimitiveGroupValueBuilder, &ArrayRef, &[usize]), + E: FnMut( + &PrimitiveGroupValueBuilder, + &[usize], + &ArrayRef, + &[usize], + &mut Vec, + ), + { + // Will cover such cases: + // - exist null, input not null + // - exist null, input null; values not equal + // - exist null, input null; values equal + // - exist not null, input null + // - exist not null, input not null; values not equal + // - exist not null, input not null; values equal + + // Define PrimitiveGroupValueBuilder + let mut builder = PrimitiveGroupValueBuilder::::new(); + let builder_array = Arc::new(Int64Array::from(vec![ + None, + None, + None, + Some(1), + Some(2), + Some(3), + ])) as ArrayRef; + append(&mut builder, &builder_array, &[0, 1, 2, 3, 4, 5]); + + // Define input array + let (_nulls, values, _) = + Int64Array::from(vec![Some(1), Some(2), None, None, Some(1), 
Some(3)]) + .into_parts(); + + // explicitly build a boolean buffer where one of the null values also happens to match + let mut boolean_buffer_builder = BooleanBufferBuilder::new(6); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(false); // this sets Some(2) to null above + boolean_buffer_builder.append(false); + boolean_buffer_builder.append(false); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(true); + let nulls = NullBuffer::new(boolean_buffer_builder.finish()); + let input_array = Arc::new(Int64Array::new(values, Some(nulls))) as ArrayRef; + + // Check + let mut equal_to_results = vec![true; builder.len()]; + equal_to( + &builder, + &[0, 1, 2, 3, 4, 5], + &input_array, + &[0, 1, 2, 3, 4, 5], + &mut equal_to_results, + ); + + assert!(!equal_to_results[0]); + assert!(equal_to_results[1]); + assert!(equal_to_results[2]); + assert!(!equal_to_results[3]); + assert!(!equal_to_results[4]); + assert!(equal_to_results[5]); + } + + #[test] + fn test_not_nullable_primitive_equal_to() { + let append = |builder: &mut PrimitiveGroupValueBuilder, + builder_array: &ArrayRef, + append_rows: &[usize]| { + for &index in append_rows { + builder.append_val(builder_array, index); + } + }; + + let equal_to = |builder: &PrimitiveGroupValueBuilder, + lhs_rows: &[usize], + input_array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut Vec| { + let iter = lhs_rows.iter().zip(rhs_rows.iter()); + for (idx, (&lhs_row, &rhs_row)) in iter.enumerate() { + equal_to_results[idx] = builder.equal_to(lhs_row, input_array, rhs_row); + } + }; + + test_not_nullable_primitive_equal_to_internal(append, equal_to); + } + + #[test] + fn test_not_nullable_primitive_vectorized_equal_to() { + let append = |builder: &mut PrimitiveGroupValueBuilder, + builder_array: &ArrayRef, + append_rows: &[usize]| { + builder.vectorized_append(builder_array, append_rows); + }; + + let equal_to = |builder: &PrimitiveGroupValueBuilder, + lhs_rows: &[usize], + input_array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut Vec| { + builder.vectorized_equal_to( + lhs_rows, + input_array, + rhs_rows, + equal_to_results, + ); + }; + + test_not_nullable_primitive_equal_to_internal(append, equal_to); + } + + fn test_not_nullable_primitive_equal_to_internal(mut append: A, mut equal_to: E) + where + A: FnMut(&mut PrimitiveGroupValueBuilder, &ArrayRef, &[usize]), + E: FnMut( + &PrimitiveGroupValueBuilder, + &[usize], + &ArrayRef, + &[usize], + &mut Vec, + ), + { + // Will cover such cases: + // - values equal + // - values not equal + + // Define PrimitiveGroupValueBuilder + let mut builder = PrimitiveGroupValueBuilder::::new(); + let builder_array = + Arc::new(Int64Array::from(vec![Some(0), Some(1)])) as ArrayRef; + append(&mut builder, &builder_array, &[0, 1]); + + // Define input array + let input_array = Arc::new(Int64Array::from(vec![Some(0), Some(2)])) as ArrayRef; + + // Check + let mut equal_to_results = vec![true; builder.len()]; + equal_to( + &builder, + &[0, 1], + &input_array, + &[0, 1], + &mut equal_to_results, + ); + + assert!(equal_to_results[0]); + assert!(!equal_to_results[1]); + } + + #[test] + fn test_nullable_primitive_vectorized_operation_special_case() { + // Test the special `all nulls` or `not nulls` input array case + // for vectorized append and equal to + + let mut builder = PrimitiveGroupValueBuilder::::new(); + + // All nulls input array + let all_nulls_input_array = Arc::new(Int64Array::from(vec![ + Option::::None, + None, + None, + None, + None, + ])) as _; + 
builder.vectorized_append(&all_nulls_input_array, &[0, 1, 2, 3, 4]); + + let mut equal_to_results = vec![true; all_nulls_input_array.len()]; + builder.vectorized_equal_to( + &[0, 1, 2, 3, 4], + &all_nulls_input_array, + &[0, 1, 2, 3, 4], + &mut equal_to_results, + ); + + assert!(equal_to_results[0]); + assert!(equal_to_results[1]); + assert!(equal_to_results[2]); + assert!(equal_to_results[3]); + assert!(equal_to_results[4]); + + // All not nulls input array + let all_not_nulls_input_array = Arc::new(Int64Array::from(vec![ + Some(1), + Some(2), + Some(3), + Some(4), + Some(5), + ])) as _; + builder.vectorized_append(&all_not_nulls_input_array, &[0, 1, 2, 3, 4]); + + let mut equal_to_results = vec![true; all_not_nulls_input_array.len()]; + builder.vectorized_equal_to( + &[5, 6, 7, 8, 9], + &all_not_nulls_input_array, + &[0, 1, 2, 3, 4], + &mut equal_to_results, + ); + + assert!(equal_to_results[0]); + assert!(equal_to_results[1]); + assert!(equal_to_results[2]); + assert!(equal_to_results[3]); + assert!(equal_to_results[4]); + } +} From 5467a28ea83f006ebd7aedeeff33ce21da848299 Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Wed, 13 Nov 2024 01:18:45 -0500 Subject: [PATCH 14/17] feat: Add boolean column to aggregate queries for fuzz testing (#13331) * add bool col * clippy fix * remove change * fmt fix * typo fix --- .../core/tests/fuzz_cases/aggregate_fuzz.rs | 1 + .../aggregation_fuzzer/data_generator.rs | 40 +++++++++-- test-utils/src/array_gen/boolean.rs | 68 +++++++++++++++++++ test-utils/src/array_gen/mod.rs | 2 + 4 files changed, 106 insertions(+), 5 deletions(-) create mode 100644 test-utils/src/array_gen/boolean.rs diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 792e23b519e0..29e1d7bc22ec 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -225,6 +225,7 @@ fn baseline_config() -> DatasetGeneratorConfig { // low cardinality columns ColumnDescr::new("u8_low", DataType::UInt8).with_max_num_distinct(10), ColumnDescr::new("utf8_low", DataType::Utf8).with_max_num_distinct(10), + ColumnDescr::new("bool", DataType::Boolean), ColumnDescr::new("binary", DataType::Binary), ColumnDescr::new("large_binary", DataType::LargeBinary), ColumnDescr::new("binaryview", DataType::BinaryView), diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs index fd4e3c40db2a..e4c0cb6fe77f 100644 --- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs @@ -18,9 +18,9 @@ use std::sync::Arc; use arrow::datatypes::{ - BinaryType, BinaryViewType, ByteArrayType, ByteViewType, Date32Type, Date64Type, - Decimal128Type, Decimal256Type, Float32Type, Float64Type, Int16Type, Int32Type, - Int64Type, Int8Type, IntervalDayTimeType, IntervalMonthDayNanoType, + BinaryType, BinaryViewType, BooleanType, ByteArrayType, ByteViewType, Date32Type, + Date64Type, Decimal128Type, Decimal256Type, Float32Type, Float64Type, Int16Type, + Int32Type, Int64Type, Int8Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType, LargeBinaryType, LargeUtf8Type, StringViewType, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, @@ -38,8 +38,8 @@ use rand::{ }; use test_utils::{ 
array_gen::{ - BinaryArrayGenerator, DecimalArrayGenerator, PrimitiveArrayGenerator, - StringArrayGenerator, + BinaryArrayGenerator, BooleanArrayGenerator, DecimalArrayGenerator, + PrimitiveArrayGenerator, StringArrayGenerator, }, stagger_batch, }; @@ -269,6 +269,26 @@ macro_rules! generate_decimal_array { }}; } +// Generating `BooleanArray` due to it being a special type in Arrow (bit-packed) +macro_rules! generate_boolean_array { + ($SELF:ident, $NUM_ROWS:ident, $MAX_NUM_DISTINCT:expr, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $ARROW_TYPE: ident) => {{ + // Select a null percentage from the candidate percentages + let null_pct_idx = $BATCH_GEN_RNG.gen_range(0..$SELF.candidate_null_pcts.len()); + let null_pct = $SELF.candidate_null_pcts[null_pct_idx]; + + let num_distinct_booleans = if $MAX_NUM_DISTINCT >= 2 { 2 } else { 1 }; + + let mut generator = BooleanArrayGenerator { + num_booleans: $NUM_ROWS, + num_distinct_booleans, + null_pct, + rng: $ARRAY_GEN_RNG, + }; + + generator.gen_data::<$ARROW_TYPE>() + }}; +} + macro_rules! generate_primitive_array { ($SELF:ident, $NUM_ROWS:ident, $MAX_NUM_DISTINCT:expr, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $ARROW_TYPE:ident) => {{ let null_pct_idx = $BATCH_GEN_RNG.gen_range(0..$SELF.candidate_null_pcts.len()); @@ -689,6 +709,16 @@ impl RecordBatchGenerator { StringViewType ) } + DataType::Boolean => { + generate_boolean_array! { + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + BooleanType + } + } _ => { panic!("Unsupported data generator type: {}", col.column_type) } diff --git a/test-utils/src/array_gen/boolean.rs b/test-utils/src/array_gen/boolean.rs new file mode 100644 index 000000000000..f3b83dd245f7 --- /dev/null +++ b/test-utils/src/array_gen/boolean.rs @@ -0,0 +1,68 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
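+
+// A minimal usage sketch (the seed, row count and null percentage are
+// assumptions for illustration, not values used by the fuzzer):
+//
+//     let mut generator = BooleanArrayGenerator {
+//         num_booleans: 1_000,
+//         num_distinct_booleans: 2,
+//         null_pct: 0.05,
+//         rng: StdRng::seed_from_u64(0),
+//     };
+//     let array: ArrayRef = generator.gen_data::<BooleanType>();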
+ +use arrow::array::{ArrayRef, BooleanArray, BooleanBuilder, UInt32Array}; +use arrow::compute::take; +use rand::rngs::StdRng; +use rand::Rng; + +/// Randomly generate boolean arrays +pub struct BooleanArrayGenerator { + pub num_booleans: usize, + pub num_distinct_booleans: usize, + pub null_pct: f64, + pub rng: StdRng, +} + +impl BooleanArrayGenerator { + /// Generate BooleanArray with bit-packed values + pub fn gen_data(&mut self) -> ArrayRef { + // Table of booleans from which to draw (distinct means 1 or 2) + let distinct_booleans: BooleanArray = match self.num_distinct_booleans { + 1 => { + let value = self.rng.gen::(); + let mut builder = BooleanBuilder::with_capacity(1); + builder.append_value(value); + builder.finish() + } + 2 => { + let mut builder = BooleanBuilder::with_capacity(2); + builder.append_value(true); + builder.append_value(false); + builder.finish() + } + _ => unreachable!(), + }; + + // Generate indices to select from the distinct booleans + let indices: UInt32Array = (0..self.num_booleans) + .map(|_| { + if self.rng.gen::() < self.null_pct { + None + } else if self.num_distinct_booleans > 1 { + Some(self.rng.gen_range(0..self.num_distinct_booleans as u32)) + } else { + Some(0) + } + }) + .collect(); + + let options = None; + + take(&distinct_booleans, &indices, options).unwrap() + } +} diff --git a/test-utils/src/array_gen/mod.rs b/test-utils/src/array_gen/mod.rs index d076bb1b6f0b..1d420c543f9f 100644 --- a/test-utils/src/array_gen/mod.rs +++ b/test-utils/src/array_gen/mod.rs @@ -16,12 +16,14 @@ // under the License. mod binary; +mod boolean; mod decimal; mod primitive; mod random_data; mod string; pub use binary::BinaryArrayGenerator; +pub use boolean::BooleanArrayGenerator; pub use decimal::DecimalArrayGenerator; pub use primitive::PrimitiveArrayGenerator; pub use string::StringArrayGenerator; From fd092e0579da5121ac75130f6b2e92da47034308 Mon Sep 17 00:00:00 2001 From: Oleks V Date: Tue, 12 Nov 2024 22:22:32 -0800 Subject: [PATCH 15/17] Move filtered SMJ Full filtered join out of `join_partial` phase (#13369) * Move filtered SMJ Full filtered join out of `join_partial` phase * Move filtered SMJ Full filtered join out of `join_partial` phase * Move filtered SMJ Full filtered join out of `join_partial` phase --- datafusion/core/tests/fuzz_cases/join_fuzz.rs | 43 ++- .../src/joins/sort_merge_join.rs | 336 ++++++++++++------ .../test_files/sort_merge_join.slt | 33 +- 3 files changed, 254 insertions(+), 158 deletions(-) diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index d7a3460e4987..cf1742a30e66 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -41,6 +41,7 @@ use datafusion::physical_plan::joins::{ }; use datafusion::physical_plan::memory::MemoryExec; +use crate::fuzz_cases::join_fuzz::JoinTestType::{HjSmj, NljHj}; use datafusion::prelude::{SessionConfig, SessionContext}; use test_utils::stagger_batch_with_seed; @@ -96,7 +97,7 @@ async fn test_inner_join_1k_filtered() { JoinType::Inner, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .run_test(&[HjSmj, NljHj], false) .await } @@ -108,7 +109,7 @@ async fn test_inner_join_1k() { JoinType::Inner, None, ) - .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .run_test(&[HjSmj, NljHj], false) .await } @@ -120,7 +121,7 @@ async fn test_left_join_1k() { JoinType::Left, None, ) - .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], 
false) + .run_test(&[HjSmj, NljHj], false) .await } @@ -132,7 +133,7 @@ async fn test_left_join_1k_filtered() { JoinType::Left, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .run_test(&[HjSmj, NljHj], false) .await } @@ -144,7 +145,7 @@ async fn test_right_join_1k() { JoinType::Right, None, ) - .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .run_test(&[HjSmj, NljHj], false) .await } @@ -156,7 +157,7 @@ async fn test_right_join_1k_filtered() { JoinType::Right, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .run_test(&[HjSmj, NljHj], false) .await } @@ -168,13 +169,11 @@ async fn test_full_join_1k() { JoinType::Full, None, ) - .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .run_test(&[HjSmj, NljHj], false) .await } #[tokio::test] -// flaky for HjSmj case -// https://github.com/apache/datafusion/issues/12359 async fn test_full_join_1k_filtered() { JoinFuzzTestCase::new( make_staggered_batches(1000), @@ -182,7 +181,7 @@ async fn test_full_join_1k_filtered() { JoinType::Full, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[JoinTestType::NljHj], false) + .run_test(&[NljHj, HjSmj], false) .await } @@ -194,7 +193,7 @@ async fn test_semi_join_1k() { JoinType::LeftSemi, None, ) - .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .run_test(&[HjSmj, NljHj], false) .await } @@ -206,7 +205,7 @@ async fn test_semi_join_1k_filtered() { JoinType::LeftSemi, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .run_test(&[HjSmj, NljHj], false) .await } @@ -218,7 +217,7 @@ async fn test_anti_join_1k() { JoinType::LeftAnti, None, ) - .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .run_test(&[HjSmj, NljHj], false) .await } @@ -230,7 +229,7 @@ async fn test_anti_join_1k_filtered() { JoinType::LeftAnti, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .run_test(&[HjSmj, NljHj], false) .await } @@ -242,7 +241,7 @@ async fn test_left_mark_join_1k() { JoinType::LeftMark, None, ) - .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .run_test(&[HjSmj, NljHj], false) .await } @@ -254,7 +253,7 @@ async fn test_left_mark_join_1k_filtered() { JoinType::LeftMark, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .run_test(&[HjSmj, NljHj], false) .await } @@ -512,8 +511,8 @@ impl JoinFuzzTestCase { nlj_formatted_sorted.sort_unstable(); if debug - && ((join_tests.contains(&JoinTestType::NljHj) && nlj_rows != hj_rows) - || (join_tests.contains(&JoinTestType::HjSmj) && smj_rows != hj_rows)) + && ((join_tests.contains(&NljHj) && nlj_rows != hj_rows) + || (join_tests.contains(&HjSmj) && smj_rows != hj_rows)) { let fuzz_debug = "fuzz_test_debug"; std::fs::remove_dir_all(fuzz_debug).unwrap_or(()); @@ -533,7 +532,7 @@ impl JoinFuzzTestCase { "input2", ); - if join_tests.contains(&JoinTestType::NljHj) && nlj_rows != hj_rows { + if join_tests.contains(&NljHj) && nlj_rows != hj_rows { println!("=============== HashJoinExec =================="); hj_formatted_sorted.iter().for_each(|s| println!("{}", s)); println!("=============== NestedLoopJoinExec =================="); @@ -551,7 +550,7 @@ impl JoinFuzzTestCase { ); } - if join_tests.contains(&JoinTestType::HjSmj) && smj_rows != hj_rows { + if join_tests.contains(&HjSmj) && smj_rows != hj_rows { println!("=============== HashJoinExec 
=================="); hj_formatted_sorted.iter().for_each(|s| println!("{}", s)); println!("=============== SortMergeJoinExec =================="); @@ -570,7 +569,7 @@ impl JoinFuzzTestCase { } } - if join_tests.contains(&JoinTestType::NljHj) { + if join_tests.contains(&NljHj) { let err_msg_rowcnt = format!("NestedLoopJoinExec and HashJoinExec produced different row counts, batch_size: {}", batch_size); assert_eq!(nlj_rows, hj_rows, "{}", err_msg_rowcnt.as_str()); @@ -591,7 +590,7 @@ impl JoinFuzzTestCase { } } - if join_tests.contains(&JoinTestType::HjSmj) { + if join_tests.contains(&HjSmj) { let err_msg_row_cnt = format!("HashJoinExec and SortMergeJoinExec produced different row counts, batch_size: {}", &batch_size); assert_eq!(hj_rows, smj_rows, "{}", err_msg_row_cnt.as_str()); diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 1eb6ea632923..9307caf1c6ad 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -852,6 +852,54 @@ fn get_corrected_filter_mask( corrected_mask.extend(vec![Some(true); null_matched]); Some(corrected_mask.finish()) } + JoinType::Full => { + let mut mask: Vec> = vec![Some(true); row_indices_length]; + let mut last_true_idx = 0; + let mut first_row_idx = 0; + let mut seen_false = false; + + for i in 0..row_indices_length { + let last_index = + last_index_for_row(i, row_indices, batch_ids, row_indices_length); + let val = filter_mask.value(i); + let is_null = filter_mask.is_null(i); + + if val { + if !seen_true { + last_true_idx = i; + } + seen_true = true; + } + + if is_null || val { + mask[i] = Some(true); + } else if !is_null && !val && (seen_true || seen_false) { + mask[i] = None; + } else { + mask[i] = Some(false); + } + + if !is_null && !val { + seen_false = true; + } + + if last_index { + if seen_true { + #[allow(clippy::needless_range_loop)] + for j in first_row_idx..last_true_idx { + mask[j] = None; + } + } + + seen_true = false; + seen_false = false; + last_true_idx = 0; + first_row_idx = i + 1; + } + } + + Some(BooleanArray::from(mask)) + } // Only outer joins needs to keep track of processed rows and apply corrected filter mask _ => None, } @@ -887,6 +935,7 @@ impl Stream for SMJStream { | JoinType::LeftMark | JoinType::Right | JoinType::LeftAnti + | JoinType::Full ) { self.freeze_all()?; @@ -969,6 +1018,7 @@ impl Stream for SMJStream { | JoinType::Right | JoinType::LeftAnti | JoinType::LeftMark + | JoinType::Full ) { continue; @@ -990,6 +1040,7 @@ impl Stream for SMJStream { | JoinType::LeftSemi | JoinType::Right | JoinType::LeftAnti + | JoinType::Full | JoinType::LeftMark ) { @@ -1171,9 +1222,10 @@ impl SMJStream { // If the head batch is fully processed, dequeue it and produce output of it. 
if head_batch.range.end == head_batch.num_rows { self.freeze_dequeuing_buffered()?; - if let Some(buffered_batch) = + if let Some(mut buffered_batch) = self.buffered_data.batches.pop_front() { + self.produce_buffered_not_matched(&mut buffered_batch)?; self.free_reservation(buffered_batch)?; } } else { @@ -1401,8 +1453,8 @@ impl SMJStream { } fn freeze_all(&mut self) -> Result<()> { + self.freeze_buffered(self.buffered_data.batches.len())?; self.freeze_streamed()?; - self.freeze_buffered(self.buffered_data.batches.len(), false)?; Ok(()) } @@ -1413,7 +1465,7 @@ impl SMJStream { fn freeze_dequeuing_buffered(&mut self) -> Result<()> { self.freeze_streamed()?; // Only freeze and produce the first batch in buffered_data as the batch is fully processed - self.freeze_buffered(1, true)?; + self.freeze_buffered(1)?; Ok(()) } @@ -1422,13 +1474,7 @@ impl SMJStream { // // Applicable only in case of Full join. // - // If `output_not_matched_filter` is true, this will also produce record batches - // for buffered rows which are joined with streamed side but don't match join filter. - fn freeze_buffered( - &mut self, - batch_count: usize, - output_not_matched_filter: bool, - ) -> Result<()> { + fn freeze_buffered(&mut self, batch_count: usize) -> Result<()> { if !matches!(self.join_type, JoinType::Full) { return Ok(()); } @@ -1442,34 +1488,66 @@ impl SMJStream { &buffered_indices, buffered_batch, )? { + let num_rows = record_batch.num_rows(); + self.output_record_batches + .filter_mask + .extend(&BooleanArray::from(vec![None; num_rows])); + self.output_record_batches + .row_indices + .extend(&UInt64Array::from(vec![None; num_rows])); + self.output_record_batches + .batch_ids + .extend(vec![0; num_rows]); + self.output_record_batches.batches.push(record_batch); } buffered_batch.null_joined.clear(); + } + Ok(()) + } - // For buffered row which is joined with streamed side rows but all joined rows - // don't satisfy the join filter - if output_not_matched_filter { - let not_matched_buffered_indices = buffered_batch - .join_filter_failed_map - .iter() - .filter_map(|(idx, failed)| if *failed { Some(*idx) } else { None }) - .collect::>(); + fn produce_buffered_not_matched( + &mut self, + buffered_batch: &mut BufferedBatch, + ) -> Result<()> { + if !matches!(self.join_type, JoinType::Full) { + return Ok(()); + } - let buffered_indices = UInt64Array::from_iter_values( - not_matched_buffered_indices.iter().copied(), - ); + // For buffered row which is joined with streamed side rows but all joined rows + // don't satisfy the join filter + let not_matched_buffered_indices = buffered_batch + .join_filter_failed_map + .iter() + .filter_map(|(idx, failed)| if *failed { Some(*idx) } else { None }) + .collect::>(); - if let Some(record_batch) = produce_buffered_null_batch( - &self.schema, - &self.streamed_schema, - &buffered_indices, - buffered_batch, - )? { - self.output_record_batches.batches.push(record_batch); - } - buffered_batch.join_filter_failed_map.clear(); - } + let buffered_indices = + UInt64Array::from_iter_values(not_matched_buffered_indices.iter().copied()); + + if let Some(record_batch) = produce_buffered_null_batch( + &self.schema, + &self.streamed_schema, + &buffered_indices, + buffered_batch, + )? 
{ + //print_batches(&[record_batch.clone()]); + let num_rows = record_batch.num_rows(); + + self.output_record_batches + .filter_mask + .extend(&BooleanArray::from(vec![None; num_rows])); + self.output_record_batches + .row_indices + .extend(&UInt64Array::from(vec![None; num_rows])); + self.output_record_batches + .batch_ids + .extend(vec![0; num_rows]); + self.output_record_batches.batches.push(record_batch); } + //dbg!(&buffered_batch.join_filter_failed_map); + buffered_batch.join_filter_failed_map.clear(); + Ok(()) } @@ -1514,8 +1592,6 @@ impl SMJStream { ) }; - let streamed_columns_length = streamed_columns.len(); - // Prepare the columns we apply join filter on later. // Only for joined rows between streamed and buffered. let filter_columns = if chunk.buffered_batch_idx.is_some() { @@ -1587,6 +1663,7 @@ impl SMJStream { | JoinType::Right | JoinType::LeftAnti | JoinType::LeftMark + | JoinType::Full ) { self.output_record_batches .batches @@ -1596,7 +1673,11 @@ impl SMJStream { self.output_record_batches.batches.push(filtered_batch); } - self.output_record_batches.filter_mask.extend(&mask); + if !matches!(self.join_type, JoinType::Full) { + self.output_record_batches.filter_mask.extend(&mask); + } else { + self.output_record_batches.filter_mask.extend(pre_mask); + } self.output_record_batches .row_indices .extend(&streamed_indices); @@ -1610,83 +1691,26 @@ impl SMJStream { // I.e., if all rows joined from a streamed row are failed with the join filter, // we need to join it with nulls as buffered side. if matches!(self.join_type, JoinType::Full) { - // We need to get the mask for row indices that the joined rows are failed - // on the join filter. I.e., for a row in streamed side, if all joined rows - // between it and all buffered rows are failed on the join filter, we need to - // output it with null columns from buffered side. For the mask here, it - // behaves like LeftAnti join. - let not_mask = if mask.null_count() > 0 { - // If the mask contains nulls, we need to use `prep_null_mask_filter` to - // handle the nulls in the mask as false to produce rows where the mask - // was null itself. - compute::not(&compute::prep_null_mask_filter(&mask))? - } else { - compute::not(&mask)? - }; + let buffered_batch = &mut self.buffered_data.batches + [chunk.buffered_batch_idx.unwrap()]; - let null_joined_batch = - filter_record_batch(&output_batch, ¬_mask)?; - - let buffered_columns = self - .buffered_schema - .fields() - .iter() - .map(|f| { - new_null_array( - f.data_type(), - null_joined_batch.num_rows(), - ) - }) - .collect::>(); - - let columns = { - let mut streamed_columns = null_joined_batch - .columns() - .iter() - .take(streamed_columns_length) - .cloned() - .collect::>(); - - streamed_columns.extend(buffered_columns); - streamed_columns - }; - - // Push the streamed/buffered batch joined nulls to the output - let null_joined_streamed_batch = - RecordBatch::try_new(Arc::clone(&self.schema), columns)?; - - self.output_record_batches - .batches - .push(null_joined_streamed_batch); - - // For full join, we also need to output the null joined rows from the buffered side. - // Usually this is done by `freeze_buffered`. However, if a buffered row is joined with - // streamed side, it won't be outputted by `freeze_buffered`. - // We need to check if a buffered row is joined with streamed side and output. - // If it is joined with streamed side, but doesn't match the join filter, - // we need to output it with nulls as streamed side. 
- if matches!(self.join_type, JoinType::Full) { - let buffered_batch = &mut self.buffered_data.batches - [chunk.buffered_batch_idx.unwrap()]; - - for i in 0..pre_mask.len() { - // If the buffered row is not joined with streamed side, - // skip it. - if buffered_indices.is_null(i) { - continue; - } + for i in 0..pre_mask.len() { + // If the buffered row is not joined with streamed side, + // skip it. + if buffered_indices.is_null(i) { + continue; + } - let buffered_index = buffered_indices.value(i); + let buffered_index = buffered_indices.value(i); - buffered_batch.join_filter_failed_map.insert( - buffered_index, - *buffered_batch - .join_filter_failed_map - .get(&buffered_index) - .unwrap_or(&true) - && !pre_mask.value(i), - ); - } + buffered_batch.join_filter_failed_map.insert( + buffered_index, + *buffered_batch + .join_filter_failed_map + .get(&buffered_index) + .unwrap_or(&true) + && !pre_mask.value(i), + ); } } } else { @@ -1726,6 +1750,7 @@ impl SMJStream { | JoinType::Right | JoinType::LeftAnti | JoinType::LeftMark + | JoinType::Full )) { self.output_record_batches.batches.clear(); @@ -1735,12 +1760,28 @@ impl SMJStream { fn filter_joined_batch(&mut self) -> Result { let record_batch = self.output_record_batch_and_reset()?; - let out_indices = self.output_record_batches.row_indices.finish(); - let out_mask = self.output_record_batches.filter_mask.finish(); + let mut out_indices = self.output_record_batches.row_indices.finish(); + let mut out_mask = self.output_record_batches.filter_mask.finish(); + let mut batch_ids = &self.output_record_batches.batch_ids; + let default_batch_ids = vec![0; record_batch.num_rows()]; + + if out_indices.null_count() == out_indices.len() + && out_indices.len() != record_batch.num_rows() + { + out_mask = BooleanArray::from(vec![None; record_batch.num_rows()]); + out_indices = UInt64Array::from(vec![None; record_batch.num_rows()]); + batch_ids = &default_batch_ids; + } + + if out_mask.is_empty() { + self.output_record_batches.batches.clear(); + return Ok(record_batch); + } + let maybe_corrected_mask = get_corrected_filter_mask( self.join_type, &out_indices, - &self.output_record_batches.batch_ids, + batch_ids, &out_mask, record_batch.num_rows(), ); @@ -1753,8 +1794,8 @@ impl SMJStream { let mut filtered_record_batch = filter_record_batch(&record_batch, corrected_mask)?; - let buffered_columns_length = self.buffered_schema.fields.len(); - let streamed_columns_length = self.streamed_schema.fields.len(); + let left_columns_length = self.streamed_schema.fields.len(); + let right_columns_length = self.buffered_schema.fields.len(); if matches!( self.join_type, @@ -1773,18 +1814,17 @@ impl SMJStream { let streamed_columns = null_joined_batch .columns() .iter() - .skip(buffered_columns_length) + .skip(left_columns_length) .cloned() .collect::>(); buffered_columns.extend(streamed_columns); buffered_columns } else { - // Left join or full outer join let mut streamed_columns = null_joined_batch .columns() .iter() - .take(streamed_columns_length) + .take(right_columns_length) .cloned() .collect::>(); @@ -1801,15 +1841,75 @@ impl SMJStream { &[filtered_record_batch, null_joined_streamed_batch], )?; } else if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) { - let output_column_indices = (0..streamed_columns_length).collect::>(); + let output_column_indices = (0..left_columns_length).collect::>(); filtered_record_batch = filtered_record_batch.project(&output_column_indices)?; + } else if matches!(self.join_type, JoinType::Full) + && 
corrected_mask.false_count() > 0 + { + // Find rows which joined by key but Filter predicate evaluated as false + let joined_filter_not_matched_mask = compute::not(corrected_mask)?; + let joined_filter_not_matched_batch = + filter_record_batch(&record_batch, &joined_filter_not_matched_mask)?; + + // Add left unmatched rows adding the right side as nulls + let right_null_columns = self + .buffered_schema + .fields() + .iter() + .map(|f| { + new_null_array( + f.data_type(), + joined_filter_not_matched_batch.num_rows(), + ) + }) + .collect::>(); + + let mut result_joined = joined_filter_not_matched_batch + .columns() + .iter() + .take(left_columns_length) + .cloned() + .collect::>(); + + result_joined.extend(right_null_columns); + + let left_null_joined_batch = + RecordBatch::try_new(Arc::clone(&self.schema), result_joined)?; + + // Add right unmatched rows adding the left side as nulls + let mut result_joined = self + .streamed_schema + .fields() + .iter() + .map(|f| { + new_null_array( + f.data_type(), + joined_filter_not_matched_batch.num_rows(), + ) + }) + .collect::>(); + + let right_data = joined_filter_not_matched_batch + .columns() + .iter() + .skip(left_columns_length) + .cloned() + .collect::>(); + + result_joined.extend(right_data); + + filtered_record_batch = concat_batches( + &self.schema, + &[filtered_record_batch, left_null_joined_batch], + )?; } self.output_record_batches.batches.clear(); - self.output_record_batches.batch_ids = vec![]; + self.output_record_batches.batch_ids.clear(); self.output_record_batches.filter_mask = BooleanBuilder::new(); self.output_record_batches.row_indices = UInt64Builder::new(); + Ok(filtered_record_batch) } } diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt index f4cc888d6b8e..9a20e7987ff6 100644 --- a/datafusion/sqllogictest/test_files/sort_merge_join.slt +++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt @@ -126,24 +126,21 @@ Alice 50 Alice 1 Alice 50 Alice 2 Bob 1 NULL NULL -# Uncomment when filtered FULL moved -# full join with join filter -#query TITI rowsort -#SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t2.b * 50 > t1.b -#---- -#Alice 100 NULL NULL -#Alice 50 Alice 2 -#Bob 1 NULL NULL -#NULL NULL Alice 1 - -# Uncomment when filtered FULL moved -#query TITI rowsort -#SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t1.b > t2.b + 50 -#---- -#Alice 100 Alice 1 -#Alice 100 Alice 2 -#Alice 50 NULL NULL -#Bob 1 NULL NULL +query TITI rowsort +SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t2.b * 50 > t1.b +---- +Alice 100 NULL NULL +Alice 50 Alice 2 +Bob 1 NULL NULL +NULL NULL Alice 1 + +query TITI rowsort +SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t1.b > t2.b + 50 +---- +Alice 100 Alice 1 +Alice 100 Alice 2 +Alice 50 NULL NULL +Bob 1 NULL NULL statement ok DROP TABLE t1; From 4c1ec807fb70ef8abe96299760dd2faaa556a49c Mon Sep 17 00:00:00 2001 From: AnthonyZhOon <126740410+AnthonyZhOon@users.noreply.github.com> Date: Wed, 13 Nov 2024 19:47:16 +1100 Subject: [PATCH 16/17] Docs: Update dependencies in `requirements.txt` for python3.12 (#13339) * Update requirements.txt for python3.12 * Update requirements.txt with a minimum version for `setuptools` * Bump to python 3.12 in docs CI --- .github/workflows/docs.yaml | 4 ++-- docs/requirements.txt | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 44ca5aaf4eda..0b43339f57a6 100644 --- a/.github/workflows/docs.yaml +++ 
b/.github/workflows/docs.yaml @@ -26,7 +26,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v5 with: - python-version: "3.10" + python-version: "3.12" - name: Install dependencies run: | @@ -61,4 +61,4 @@ jobs: git add --all git commit -m 'Publish built docs triggered by ${{ github.sha }}' git push || git push --force - fi \ No newline at end of file + fi diff --git a/docs/requirements.txt b/docs/requirements.txt index 24546d59a45a..bd030fb67044 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -20,3 +20,4 @@ pydata-sphinx-theme==0.8.0 myst-parser maturin jinja2 +setuptools>=48.0.0 From 3a7dde3a1f974489b8b279c1143bc9cebedf7682 Mon Sep 17 00:00:00 2001 From: ding-young Date: Wed, 13 Nov 2024 21:01:43 +0900 Subject: [PATCH 17/17] Remove uses of #[allow(dead_code)] in favor of _identifier (#13328) * Remove uses of #[allow(dead_code)] in favor of _identifier * update comments --- datafusion-examples/examples/advanced_parquet_index.rs | 8 ++++---- datafusion/core/tests/parquet/external_access_plan.rs | 10 +++++----- datafusion/core/tests/parquet/mod.rs | 5 ++--- datafusion/execution/src/disk_manager.rs | 5 ++--- datafusion/physical-plan/src/joins/cross_join.rs | 5 ++--- datafusion/physical-plan/src/joins/hash_join.rs | 5 ++--- datafusion/physical-plan/src/joins/nested_loop_join.rs | 7 +++---- datafusion/physical-plan/src/repartition/mod.rs | 10 ++++------ datafusion/physical-plan/src/sorts/cursor.rs | 8 +++++--- 9 files changed, 29 insertions(+), 34 deletions(-) diff --git a/datafusion-examples/examples/advanced_parquet_index.rs b/datafusion-examples/examples/advanced_parquet_index.rs index f6860bb5b87a..67b745d4074e 100644 --- a/datafusion-examples/examples/advanced_parquet_index.rs +++ b/datafusion-examples/examples/advanced_parquet_index.rs @@ -229,9 +229,9 @@ async fn main() -> Result<()> { /// `file1.parquet` contains values `0..1000` #[derive(Debug)] pub struct IndexTableProvider { - /// Where the file is stored (cleanup on drop) - #[allow(dead_code)] - tmpdir: TempDir, + /// Pointer to temporary file storage. Keeping it in scope to prevent temporary folder + /// to be deleted prematurely + _tmpdir: TempDir, /// The file that is being read. indexed_file: IndexedFile, /// The underlying object store @@ -250,7 +250,7 @@ impl IndexTableProvider { Ok(Self { indexed_file, - tmpdir, + _tmpdir: tmpdir, object_store, use_row_selections: AtomicBool::new(false), }) diff --git a/datafusion/core/tests/parquet/external_access_plan.rs b/datafusion/core/tests/parquet/external_access_plan.rs index 03afc858dfca..96267eeff5a7 100644 --- a/datafusion/core/tests/parquet/external_access_plan.rs +++ b/datafusion/core/tests/parquet/external_access_plan.rs @@ -313,7 +313,7 @@ impl TestFull { } = self; let TestData { - temp_file: _, + _temp_file: _, schema, file_name, file_size, @@ -361,9 +361,9 @@ impl TestFull { // Holds necessary data for these tests to reuse the same parquet file struct TestData { - // field is present as on drop the file is deleted - #[allow(dead_code)] - temp_file: NamedTempFile, + /// Pointer to temporary file storage. 
Keeping it in scope to prevent temporary folder + /// to be deleted prematurely + _temp_file: NamedTempFile, schema: SchemaRef, file_name: String, file_size: u64, @@ -402,7 +402,7 @@ fn get_test_data() -> &'static TestData { let file_size = temp_file.path().metadata().unwrap().len(); TestData { - temp_file, + _temp_file: temp_file, schema, file_name, file_size, diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index cfa2a3df3ba2..cd298d1c5543 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -100,10 +100,9 @@ enum Unit { /// table "t" registered, pointing at a parquet file made with /// `make_test_file` struct ContextWithParquet { - #[allow(dead_code)] /// temp file parquet data is written to. The file is cleaned up /// when dropped - file: NamedTempFile, + _file: NamedTempFile, provider: Arc, ctx: SessionContext, } @@ -217,7 +216,7 @@ impl ContextWithParquet { ctx.register_table("t", provider.clone()).unwrap(); Self { - file, + _file: file, provider, ctx, } diff --git a/datafusion/execution/src/disk_manager.rs b/datafusion/execution/src/disk_manager.rs index 38c259fcbdc8..c71071b8093c 100644 --- a/datafusion/execution/src/disk_manager.rs +++ b/datafusion/execution/src/disk_manager.rs @@ -139,7 +139,7 @@ impl DiskManager { let dir_index = thread_rng().gen_range(0..local_dirs.len()); Ok(RefCountedTempFile { - parent_temp_dir: Arc::clone(&local_dirs[dir_index]), + _parent_temp_dir: Arc::clone(&local_dirs[dir_index]), tempfile: Builder::new() .tempfile_in(local_dirs[dir_index].as_ref()) .map_err(DataFusionError::IoError)?, @@ -153,8 +153,7 @@ impl DiskManager { pub struct RefCountedTempFile { /// The reference to the directory in which temporary files are created to ensure /// it is not cleaned up prior to the NamedTempFile - #[allow(dead_code)] - parent_temp_dir: Arc, + _parent_temp_dir: Arc, tempfile: NamedTempFile, } diff --git a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs index 7f785006f755..f53fe13df15e 100644 --- a/datafusion/physical-plan/src/joins/cross_join.rs +++ b/datafusion/physical-plan/src/joins/cross_join.rs @@ -53,8 +53,7 @@ struct JoinLeftData { merged_batch: RecordBatch, /// Track memory reservation for merged_batch. Relies on drop /// semantics to release reservation when JoinLeftData is dropped. - #[allow(dead_code)] - reservation: MemoryReservation, + _reservation: MemoryReservation, } #[allow(rustdoc::private_intra_doc_links)] @@ -209,7 +208,7 @@ async fn load_left_input( Ok(JoinLeftData { merged_batch, - reservation, + _reservation: reservation, }) } diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 32267b118193..8ab292c14269 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -92,8 +92,7 @@ struct JoinLeftData { probe_threads_counter: AtomicUsize, /// Memory reservation that tracks memory used by `hash_map` hash table /// `batch`. Cleared on drop. 
- #[allow(dead_code)] - reservation: MemoryReservation, + _reservation: MemoryReservation, } impl JoinLeftData { @@ -110,7 +109,7 @@ impl JoinLeftData { batch, visited_indices_bitmap, probe_threads_counter, - reservation, + _reservation: reservation, } } diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index 71c617a96300..2beeb92da499 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -69,8 +69,7 @@ struct JoinLeftData { probe_threads_counter: AtomicUsize, /// Memory reservation for tracking batch and bitmap /// Cleared on `JoinLeftData` drop - #[allow(dead_code)] - reservation: MemoryReservation, + _reservation: MemoryReservation, } impl JoinLeftData { @@ -78,13 +77,13 @@ impl JoinLeftData { batch: RecordBatch, bitmap: SharedBitmapBuilder, probe_threads_counter: AtomicUsize, - reservation: MemoryReservation, + _reservation: MemoryReservation, ) -> Self { Self { batch, bitmap, probe_threads_counter, - reservation, + _reservation, } } diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index 1730c7d8dc61..0a80dcd34e05 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -623,7 +623,7 @@ impl ExecutionPlan for RepartitionExec { Box::pin(PerPartitionStream { schema: Arc::clone(&schema_captured), receiver, - drop_helper: Arc::clone(&abort_helper), + _drop_helper: Arc::clone(&abort_helper), reservation: Arc::clone(&reservation), }) as SendableRecordBatchStream }) @@ -651,7 +651,7 @@ impl ExecutionPlan for RepartitionExec { num_input_partitions_processed: 0, schema: input.schema(), input: rx.swap_remove(0), - drop_helper: abort_helper, + _drop_helper: abort_helper, reservation, }) as SendableRecordBatchStream) } @@ -906,8 +906,7 @@ struct RepartitionStream { input: DistributionReceiver, /// Handle to ensure background tasks are killed when no longer needed. - #[allow(dead_code)] - drop_helper: Arc>>, + _drop_helper: Arc>>, /// Memory reservation. reservation: SharedMemoryReservation, @@ -970,8 +969,7 @@ struct PerPartitionStream { receiver: DistributionReceiver, /// Handle to ensure background tasks are killed when no longer needed. - #[allow(dead_code)] - drop_helper: Arc>>, + _drop_helper: Arc>>, /// Memory reservation. reservation: SharedMemoryReservation, diff --git a/datafusion/physical-plan/src/sorts/cursor.rs b/datafusion/physical-plan/src/sorts/cursor.rs index 133d736c1467..5cd24b89f5c1 100644 --- a/datafusion/physical-plan/src/sorts/cursor.rs +++ b/datafusion/physical-plan/src/sorts/cursor.rs @@ -156,8 +156,7 @@ pub struct RowValues { /// Tracks for the memory used by in the `Rows` of this /// cursor. Freed on drop - #[allow(dead_code)] - reservation: MemoryReservation, + _reservation: MemoryReservation, } impl RowValues { @@ -173,7 +172,10 @@ impl RowValues { "memory reservation mismatch" ); assert!(rows.num_rows() > 0); - Self { rows, reservation } + Self { + rows, + _reservation: reservation, + } } }
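The convention this last patch adopts is worth spelling out: instead of silencing the unused-field lint with `#[allow(dead_code)]`, fields that exist only for their `Drop` side effect (memory reservations, temp files, abort handles) are renamed with a leading underscore. The underscore both suppresses the lint and documents the intent at the declaration site. Below is a small self-contained sketch of the same pattern; it is not DataFusion code, and the struct and field names are invented for illustration.

```rust
// Minimal sketch of the `_field` pattern: the field is never read, it is
// kept only so its destructor runs when the owning struct is dropped.
use tempfile::TempDir;

struct ScratchSpace {
    /// Kept alive so the temporary directory is removed only when
    /// `ScratchSpace` itself is dropped (no `#[allow(dead_code)]` needed).
    _tmpdir: TempDir,
    path: std::path::PathBuf,
}

impl ScratchSpace {
    fn new() -> std::io::Result<Self> {
        let tmpdir = TempDir::new()?;
        let path = tmpdir.path().join("scratch.bin");
        Ok(Self { _tmpdir: tmpdir, path })
    }
}

fn main() -> std::io::Result<()> {
    let scratch = ScratchSpace::new()?;
    println!("writing to {}", scratch.path.display());
    // `scratch` (and with it `_tmpdir`) is dropped here, deleting the directory.
    Ok(())
}
```

Compared with the attribute, the underscore keeps the dead-code lint active for fields that really are unused by mistake, which is the motivation stated in the patch's doc comments ("Keeping it in scope to prevent the temporary folder from being deleted prematurely").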